aGrUM  0.20.3
a C++ library for (probabilistic) graphical models
O3prmrInterpreter.cpp
Go to the documentation of this file.
1 /**
2  *
3  * Copyright (c) 2005-2021 by Pierre-Henri WUILLEMIN(@LIP6) & Christophe GONZALES(@AMU)
4  * info_at_agrum_dot_org
5  *
6  * This library is free software: you can redistribute it and/or modify
7  * it under the terms of the GNU Lesser General Public License as published by
8  * the Free Software Foundation, either version 3 of the License, or
9  * (at your option) any later version.
10  *
11  * This library is distributed in the hope that it will be useful,
12  * but WITHOUT ANY WARRANTY; without even the implied warranty of
13  * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14  * GNU Lesser General Public License for more details.
15  *
16  * You should have received a copy of the GNU Lesser General Public License
17  * along with this library. If not, see <http://www.gnu.org/licenses/>.
18  *
19  */
20 
21 
22 /**
23  * @file
24  * @brief Implementation of O3prmrInterpreter.
25  *
26  * @author Pierre-Henri WUILLEMIN(@LIP6), Ni NI, Lionel TORTI & Vincent RENAUDINEAU
27  */
28 
29 #include <agrum/agrum.h>
30 
31 #include <agrum/BN/BayesNet.h>
32 #include <agrum/BN/inference/lazyPropagation.h>
33 #include <agrum/BN/inference/tools/BayesNetInference.h>
34 #include <agrum/BN/inference/variableElimination.h>
35 
36 #include <agrum/PRM/inference/SVE.h>
37 #include <agrum/PRM/inference/SVED.h>
38 #include <agrum/PRM/inference/groundedInference.h>
39 #include <agrum/PRM/o3prmr/O3prmrInterpreter.h>
40 
41 #include <agrum/PRM/o3prmr/cocoR/Parser.h>
42 
43 namespace gum {
44 
45  namespace prm {
46 
47  namespace o3prmr {
48 
49  /* **************************************************************************
50  */
51 
52  /// This constructor creates an empty context.
54  m_context(new O3prmrContext< double >()), m_reader(new o3prm::O3prmReader< double >()),
55  m_bn(0), m_inf(0), m_syntax_flag(false), m_verbose(false), m_log(std::cout),
56  m_current_line(-1) {}
57 
58  /// Destructor. Delete current context.
60  delete m_context;
61  if (m_bn) { delete m_bn; }
62  for (auto p: m_inf_map) {
63  delete p.second;
64  }
65  delete m_reader->prm();
66  delete m_reader;
67  }
68 
69  /* **************************************************************************
70  */
71 
72  /// Getter for the context: the interpreter owns the returned object (it is deleted when the context is replaced or the interpreter is destroyed).
73  auto O3prmrInterpreter::getContext() const -> O3prmrContext< double >* { return this->m_context; }
74 
75  /// Setter for the context.
77  delete m_context;
78 
79  if (context == 0)
80  m_context = new O3prmrContext< double >();
81  else
83  }
84 
85  /// Root paths from which packages are searched.
86  /// Defaults are './' and one computed from the requested package, if any.
88 
89  /// Root paths from which packages are searched (a trailing '/' is appended if missing).
90  /// Defaults are './' and one computed from the requested package, if any. Throws NotFound if the path is not a directory.
92  if (path.length() && path.back() != '/') { path = path + '/'; }
93  if (Directory::isDir(path)) {
95  } else {
96  GUM_ERROR(NotFound, "not a directory")
97  }
98  }
99 
100  /// Root paths from which packages are searched.
101  /// Defaults are './' and one computed from the requested package, if any.
103 
104  /// Syntax mode does not process anything; it only checks syntax.
106 
107  /// Syntax mode does not process anything; it only checks syntax.
109 
110  /// Tells whether verbose mode is on; when it is, progress details are written to the log stream.
111  auto O3prmrInterpreter::isVerboseMode() const -> bool { return this->m_verbose; }
112 
113  /// Verbose mode shows more details about the program execution.
115 
116  /// Returns the PRM held by the internal O3PRM reader (this interpreter deletes it together with the reader).
117  auto O3prmrInterpreter::prm() const -> const PRM< double >* { return this->m_reader->prm(); }
118 
119  /// Returns the current inference engine, or a null pointer when none has been built yet.
120  auto O3prmrInterpreter::inference() const -> const PRMInference< double >* { return this->m_inf; }
121 
122  /// Returns the results of the queries run by the last processed file or command line.
123  /// Each QueryResult stores the query command, the elapsed time and its values:
124  /// a std::vector of SingleResult, each holding a label/probability pair.
125  auto O3prmrInterpreter::results() const -> const std::vector< QueryResult >& { return this->m_results; }
126 
127  /**
128  * Parse the file or the command line.
129  * If errors occurred, return false. Error messages can be retrieved with the
130  * getErrorsContainer() method.
131  * If no errors occurred, return true.
132  * Request results can be retrieved with the results() method.
133  * */
135  m_results.clear();
136 
137  try {
139 
140  delete m_context;
141  m_context = new O3prmrContext< double >(filename);
142  O3prmrContext< double > c(filename);
143 
144  // On vérifie la syntaxe
145  unsigned char* buffer = new unsigned char[file_content.length() + 1];
146  strcpy((char*)buffer, file_content.c_str());
147  Scanner s(buffer, int(file_content.length() + 1));
148  Parser p(&s);
149  p.setO3prmrContext(&c);
150  p.Parse();
151 
152  m_errors = p.errors();
153 
154  if (errors() > 0) { return false; }
155 
156  // Set paths to search from.
157  delete m_reader->prm();
158  delete m_reader;
159  m_reader = new o3prm::O3prmReader< double >();
160 
161  for (size_t i = 0; i < m_paths.size(); i++) {
163  }
164 
165  // On vérifie la sémantique.
166  if (!checkSemantic(&c)) { return false; }
167 
168  if (isInSyntaxMode()) {
169  return true;
170  } else {
171  return interpret(&c);
172  }
173  } catch (gum::Exception&) { return false; }
174  }
175 
177  // read entire file into string
179  if (istream) {
180  // get length of file:
181  istream.seekg(0, istream.end);
182  int length = int(istream.tellg());
183  istream.seekg(0, istream.beg);
184 
185  std::string str;
186  str.resize(length, ' '); // reserve space
187  char* begin = &*str.begin();
188 
190  istream.close();
191 
192  return str;
193  }
194  GUM_ERROR(OperationNotAllowed, "Could not open file")
195  }
196 
198  m_results.clear();
199 
200  // On vérifie la syntaxe
201  O3prmrContext< double > c;
202  Scanner s((unsigned char*)line.c_str(), (int)line.length());
203  Parser p(&s);
204  p.setO3prmrContext(&c);
205  p.Parse();
206  m_errors = p.errors();
207 
208  if (errors() > 0) return false;
209 
210  // On vérifie la sémantique.
211  if (!checkSemantic(&c)) return false;
212 
213  if (isInSyntaxMode())
214  return true;
215  else
216  return interpret(&c);
217  }
218 
219  /**
220  * Creates the PRM corresponding to the current context.
221  * Returns true on success, or false if interpreting the context
222  * fails (import not found or not defined,
223  * etc.).
224  * */
226  if (isVerboseMode()) m_log << "## Start interpretation." << std::endl << std::flush;
227 
228  // Don't parse if any syntax errors.
229  if (errors() > 0) return false;
230 
231  // For each session
232  std::vector< O3prmrSession< double >* > sessions = c->sessions();
233 
234  for (const auto session: sessions)
235  for (auto command: session->commands()) {
236  // We process it.
237  bool result = true;
238 
239  try {
240  switch (command->type()) {
242  result = observe((ObserveCommand< double >*)command);
243  break;
244 
246  result = unobserve((UnobserveCommand< double >*)command);
247  break;
248 
251  break;
252 
255  break;
256 
258  query((QueryCommand< double >*)command);
259  break;
260  }
261  } catch (Exception& err) {
262  result = false;
264  } catch (std::string& err) {
265  result = false;
266  addError(err);
267  }
268 
269  // If there was a problem, skip the rest of this session,
270  // unless syntax mode is activated.
271  if (!result) {
272  if (m_verbose) m_log << "Errors : skip the rest of this session." << std::endl;
273 
274  break;
275  }
276  }
277 
278  if (isVerboseMode()) m_log << "## End interpretation." << std::endl << std::flush;
279 
280  return errors() == 0;
281  }
282 
283  /* **************************************************************************
284  */
285 
286  /**
287  * Check the semantic validity of the context.
288  * First perform all imports, then check that systems, instances, attributes
289  * and
290  * labels exist.
291  * While checking, prepare data structures for interpretation.
292  * Return true if everything is correct, false otherwise.
293  *
294  * Note: checking stops at the first error unless syntax mode is activated.
295  * */
297  // Don't parse if any syntax errors.
298  if (errors() > 0) return false;
299 
300  // On importe tous les systèmes.
301  for (const auto command: context->imports()) {
303  // if import doen't succed stop here unless syntax mode is activated.
304  bool succeed = import(context, command->value);
305 
306  if (!succeed && !isInSyntaxMode()) return false;
307 
308  // En cas de succès, on met à jour le contexte global
310  }
311 
312  if (m_verbose)
313  m_log << "## Check semantic for " << context->sessions().size() << " sessions"
314  << std::endl;
315 
316  // On vérifie chaque session
317  for (const auto session: context->sessions()) {
319  O3prmrSession< double >* new_session = new O3prmrSession< double >(sessionName);
320 
321  if (m_verbose)
322  m_log << "## Start session '" << sessionName << "'..." << std::endl << std::endl;
323 
324  for (const auto command: session->commands()) {
325  if (m_verbose)
326  m_log << "# * Going to check command : " << command->toString() << std::endl;
327 
328  // Update the current line (for warnings and errors)
330 
331  // We check it.
332  bool result = true;
333 
334  try {
335  switch (command->type()) {
338  break;
339 
342  break;
343 
345  result = checkObserve((ObserveCommand< double >*)command);
346  break;
347 
350  break;
351 
353  result = checkQuery((QueryCommand< double >*)command);
354  break;
355 
356  default:
357  addError("Error : Unknow command : " + command->toString()
358  + "\n -> Command not processed.");
359  result = false;
360  }
361  } catch (Exception& err) {
362  result = false;
364  } catch (std::string& err) {
365  result = false;
366  addError(err);
367  }
368 
369  // If there was a problem, skip the rest of this session,
370  // unless syntax mode is activated.
371  if (!result && !isInSyntaxMode()) {
372  if (m_verbose) m_log << "Errors : skip the rest of this session." << std::endl;
373 
374  break;
375  }
376 
377  // On l'ajoute au contexte globale
379  }
380 
381  // Ajoute la session au contexte global,
382  // ou à la dernière session.
383  if (sessionName == "default" && m_context->sessions().size() > 0)
384  *(m_context->sessions().back()) += *new_session;
385  else
387 
388  if (m_verbose)
389  m_log << std::endl
390  << "## Session '" << sessionName << "' finished." << std::endl
391  << std::endl
392  << std::endl;
393 
394  // todo : check memory leak
395  // delete new_session; ??
396  }
397 
399 
400  return errors() == 0;
401  }
402 
405  return m_engine == "SVED" || m_engine == "GRD" || m_engine == "SVE";
406  }
407 
410  return m_bn_engine == "VE" || m_bn_engine == "VEBB" || m_bn_engine == "lazy";
411  }
412 
414  try {
417 
418  // Contruct the pair (instance,attribut)
419  const PRMSystem< double >& sys = system(left_val);
420  const PRMInstance< double >& instance = sys.get(findInstanceName(left_val, sys));
422  typename PRMInference< double >::Chain chain = std::make_pair(&instance, &attr);
423 
424  command->system = &sys;
426 
427  // Check label exists for this type.
428  // Potential<double> e;
431  bool found = false;
432 
433  for (i.setFirst(); !i.end(); i.inc()) {
435  == right_val) {
436  command->potentiel.set(i, (double)1.0);
437  found = true;
438  } else {
439  command->potentiel.set(i, (double)0.0);
440  }
441  }
442 
443  if (!found) addError(right_val + " is not a label of " + left_val);
444 
445  // else command->potentiel = e;
446 
447  return found;
448 
449  } catch (Exception& err) { addError(err.errorContent()); } catch (std::string& err) {
450  addError(err);
451  }
452 
453  return false;
454  }
455 
457  try {
459 
460  // Contruct the pair (instance,attribut)
461  const PRMSystem< double >& sys = system(name);
462  const PRMInstance< double >& instance = sys.get(findInstanceName(name, sys));
464  // PRMInference<double>::Chain chain = std::make_pair(&instance,
465  // &attr);
466 
467  command->system = &sys;
469 
470  return true;
471 
472  } catch (Exception& err) { addError(err.errorContent()); } catch (std::string& err) {
473  addError(err);
474  }
475 
476  return false;
477  }
478 
480  try {
482 
483  // Contruct the pair (instance,attribut)
484  const PRMSystem< double >& sys = system(name);
485  const PRMInstance< double >& instance = sys.get(findInstanceName(name, sys));
487  // PRMInference<double>::Chain chain = std::make_pair(&instance,
488  // &attr);
489 
490  command->system = &sys;
492 
493  return true;
494 
495  } catch (Exception& err) { addError(err.errorContent()); } catch (std::string& err) {
496  addError(err);
497  }
498 
499  return false;
500  }
501 
502  // Import the system o3prm file
503  // Return false if any error.
504 
506  try {
507  if (m_verbose) { m_log << "# Loading system '" << import_name << "' => '" << std::flush; }
508 
510 
511  std::replace(import_name.begin(), import_name.end(), '.', '/');
512  import_name += ".o3prm";
513 
514  if (m_verbose) { m_log << import_name << "' ... " << std::endl << std::flush; }
515 
517  bool found = false;
519 
520  // Search in o3prmr file dir.
522 
523  if (!o3prmrFilename.empty()) {
525 
526  if (index != std::string::npos) {
529 
530  if (m_verbose) {
531  m_log << "# Search from filedir '" << import_abs_filename << "' ... " << std::flush;
532  }
533 
535 
536  if (file_test.is_open()) {
537  if (m_verbose) { m_log << "found !" << std::endl << std::flush; }
538 
539  file_test.close();
540  found = true;
541  } else if (m_verbose) {
542  m_log << "not found." << std::endl << std::flush;
543  }
544  }
545  }
546 
547  // Deduce root path from package name.
549 
550  if (!found && !package.empty()) {
551  std::string root;
552 
553  // if filename is not empty, start from it.
555 
556  if (!filename.empty()) {
558 
559  if (size != std::string::npos) {
560  root += filename.substr(0, size + 1); // take with the '/'
561  }
562  }
563 
564  //
565  root += "../";
566  int count = (int)std::count(package.begin(), package.end(), '.');
567 
568  for (int i = 0; i < count; i++)
569  root += "../";
570 
572 
573  if (m_verbose) {
574  m_log << "# Search from package '" << package << "' => '" << import_abs_filename
575  << "' ... " << std::flush;
576  }
577 
579 
580  if (file_test.is_open()) {
581  if (m_verbose) { m_log << "found !" << std::endl << std::flush; }
582 
583  file_test.close();
584  found = true;
585  } else if (m_verbose) {
586  m_log << "not found." << std::endl << std::flush;
587  }
588  }
589 
590  // Search import in all paths.
591  for (const auto& path: m_paths) {
593 
594  if (m_verbose) {
595  m_log << "# Search from classpath '" << import_abs_filename << "' ... " << std::flush;
596  }
597 
599 
600  if (file_test.is_open()) {
601  if (m_verbose) { m_log << " found !" << std::endl << std::flush; }
602 
603  file_test.close();
604  found = true;
605  break;
606  } else if (m_verbose) {
607  m_log << " not found." << std::endl << std::flush;
608  }
609  }
610 
611  if (!found) {
612  if (m_verbose) { m_log << "Finished with errors." << std::endl; }
613 
614  addError("import not found.");
615  return false;
616  }
617 
618  // May throw std::IOError if file does't exist
621 
622  try {
624 
625  // Show errors and warning
626  if (m_verbose
627  && (m_reader->errors() > (unsigned int)previousO3prmError
628  || errors() > previousO3prmrError)) {
629  m_log << "Finished with errors." << std::endl;
630  } else if (m_verbose) {
631  m_log << "Finished." << std::endl;
632  }
633 
634  } catch (const IOError& err) {
635  if (m_verbose) { m_log << "Finished with errors." << std::endl; }
636 
638  }
639 
640  // Add o3prm errors and warnings to o3prmr errors
643  }
644 
645  return errors() == previousO3prmrError;
646 
647  } catch (const Exception& err) {
648  if (m_verbose) { m_log << "Finished with exceptions." << std::endl; }
649 
651  return false;
652  }
653  }
654 
656  size_t dot = s.find_first_of('.');
657  std::string name = s.substr(0, dot);
658 
659  // We look first for real system, next for alias.
660  if (prm()->isSystem(name)) {
661  s = s.substr(dot + 1);
662  return name;
663  }
664 
665  if (!m_context->aliasToImport(name).empty()) {
666  s = s.substr(dot + 1);
667  return m_context->aliasToImport(name);
668  }
669 
670  while (dot != std::string::npos) {
671  if (prm()->isSystem(name)) {
672  s = s.substr(dot + 1);
673  return name;
674  }
675 
676  dot = s.find('.', dot + 1);
677  name = s.substr(0, dot);
678  }
679 
680  throw "could not find any system in '" + s + "'.";
681  }
682 
684  const PRMSystem< double >& sys) {
685  // We have found system before, so 's' has been stripped.
686  size_t dot = s.find_first_of('.');
687  std::string name = s.substr(0, dot);
688 
689  if (!sys.exists(name))
690  throw "'" + name + "' is not an instance of system '" + sys.name() + "'.";
691 
692  s = s.substr(dot + 1);
693  return name;
694  }
695 
697  const PRMInstance< double >& instance) {
698  if (!instance.exists(s))
699  throw "'" + s + "' is not an attribute of instance '" + instance.name() + "'.";
700 
701  return s;
702  }
703 
704  // After this method, ident no longer contains the system name.
706  try {
707  return prm()->getSystem(findSystemName(ident));
708  } catch (const std::string&) {}
709 
710  if ((m_context->mainImport() != 0) && prm()->isSystem(m_context->mainImport()->value))
711  return prm()->getSystem(m_context->mainImport()->value);
712 
713  throw "could not find any system or alias in '" + ident
714  + "' and no default alias has been set.";
715  }
716 
717  /// Executes an observe command: adds the command's evidence to the current inference engine (built on demand). Returns true on success; records an error and returns false on failure.
718 
719  bool O3prmrInterpreter::observe(const ObserveCommand< double >* command) try { // function-try-block: the handlers below cover the whole body
720  const typename PRMInference< double >::Chain& chain = command->chain;
721 
722  // Generate the inference engine if it doesn't exist.
723  if (!m_inf) { generateInfEngine(*(command->system)); } // NOTE(review): unlike query(), m_inf is not re-selected from m_inf_map for command->system — confirm this is intended when several systems are used
724 
725  // Evidence already present on this chain only triggers a warning; execution continues.
726  if (m_inf->hasEvidence(chain)) addWarning(command->leftValue + " is already observed");
727 
729 
730  if (m_verbose)
731  m_log << "# Added evidence " << command->rightValue << " over attribute "
732  << command->leftValue << std::endl;
733 
734  return true;
735 
736  } catch (OperationNotAllowed& ex) { // presumably raised by the engine when the evidence is rejected — TODO confirm
737  addError("something went wrong when adding evidence " + command->rightValue + " over "
738  + command->leftValue + " : " + ex.errorContent());
739  return false;
740 
741  } catch (const std::string& msg) {
742  addError(msg);
743  return false;
744  }
745 
746  /// Executes an unobserve command: removes the evidence on the command's chain if there is any, otherwise only issues a warning. Returns false only when a std::string error is thrown.
747 
748  bool O3prmrInterpreter::unobserve(const UnobserveCommand< double >* command) try { // function-try-block
750  typename PRMInference< double >::Chain chain = command->chain;
751 
752  // Nothing to remove (no engine yet, or no evidence on this chain): warn instead of failing.
753  if (!m_inf || !m_inf->hasEvidence(chain)) {
754  addWarning(name + " was not observed"); // `name` presumably holds the attribute's textual name; its declaration is not shown here — TODO confirm
755  } else {
757 
758  if (m_verbose) m_log << "# Removed evidence over attribute " << name << std::endl;
759  }
760 
761  return true;
762 
763  } catch (const std::string& msg) {
764  addError(msg);
765  return false;
766  }
767 
768  /// Executes a query command: selects (or builds) the inference engine for the command's system, times the inference, and records one label/probability pair per state of the queried attribute.
769  void O3prmrInterpreter::query(const QueryCommand< double >* command) try { // function-try-block
770  const std::string& query = command->value;
771 
772  if (m_inf_map.exists(command->system)) { // presumably reuses the engine cached for this system — TODO confirm
774  } else {
775  m_inf = nullptr;
776  }
777 
778  // Create inference engine if it has not been already created.
779  if (!m_inf) { generateInfEngine(*(command->system)); }
780 
781  // Inference
782  if (m_verbose) {
783  m_log << "# Starting inference over query: " << query << "... " << std::endl;
784  }
785 
786  Timer timer;
787  timer.reset();
788 
789  Potential< double > m;
791 
792  // Elapsed time since the reset above (the verbose log reports it in seconds).
793  double t = timer.step();
794 
795  if (m_verbose) { m_log << "Finished." << std::endl; }
796 
797  if (m_verbose) { m_log << "# Time in seconds (accuracy ~0.001): " << t << std::endl; }
798 
799  // Show results
800 
801  if (m_verbose) { m_log << std::endl; }
802 
804  result.command = query;
805  result.time = t;
806 
807  Instantiation j(m);
808  const PRMAttribute< double >& attr = *(command->chain.second);
809 
810  for (j.setFirst(); !j.end(); j.inc()) {
811  // auto label_value = j.val ( attr.type().variable() );
812  auto label_value = j.val(0); // assumes the queried variable is the first variable of m — confirm
814  float value = float(m.get(j)); // NOTE(review): narrows the double probability to float
815 
818  singleResult.p = value;
819 
821 
822  if (m_verbose) { m_log << label << " : " << value << std::endl; }
823  }
824 
826 
827  if (m_verbose) { m_log << std::endl; }
828 
829  } catch (Exception& e) {
830  GUM_SHOWERROR(e);
831  throw "something went wrong while infering: " + e.errorContent(); // rethrown as std::string: interpret() catches std::string and records it as an error
832 
833  } catch (const std::string& msg) { addError(msg); } // recorded but not propagated: interpret() does not see a failure and continues the session
834 
835  ///
838  }
839 
840  ///
843  }
844 
845  ///
847  if (m_verbose) m_log << "# Building the inference engine... " << std::flush;
848 
849  //
850  if (m_engine == "SVED") {
851  m_inf = new SVED< double >(*(prm()), sys);
852 
853  //
854  } else if (m_engine == "SVE") {
855  m_inf = new SVE< double >(*(prm()), sys);
856 
857  } else {
858  if (m_engine != "GRD") {
859  addWarning("unkown engine '" + m_engine + "', use GRD insteed.");
860  }
861 
862  MarginalTargetedInference< double >* bn_inf = nullptr;
863  if (m_bn) { delete m_bn; }
864  m_bn = new BayesNet< double >();
865  BayesNetFactory< double > bn_factory(m_bn);
866 
867  if (m_verbose) m_log << "(Grounding the network... " << std::flush;
868 
870 
871  if (m_verbose) m_log << "Finished)" << std::flush;
872 
873  // bn_inf = new LazyPropagation<double>( *m_bn );
874  bn_inf = new VariableElimination< double >(m_bn);
875 
876  auto grd_inf = new GroundedInference< double >(*(prm()), sys);
878  m_inf = grd_inf;
879  }
880 
882  if (m_verbose) m_log << "Finished." << std::endl;
883  }
884 
885  /* **************************************************************************
886  */
887 
888  /// Total number of recorded messages (errors plus warnings).
889  auto O3prmrInterpreter::count() const -> Size { return this->m_errors.count(); }
890 
891  ///
893 
894  ///
896 
897  ///
899  if (i >= count()) throw "Index out of bound.";
900 
901  return m_errors.error(i);
902  }
903 
904  /// Return container with all errors.
906 
907  ///
910  }
911 
912  ///
915  }
916 
917  ///
920  }
921 
922  /* **************************************************************************
923  */
924 
925  ///
928 
929  if (m_verbose) m_log << m_errors.last().toString() << std::endl;
930  }
931 
932  ///
935 
936  if (m_verbose) m_log << m_errors.last().toString() << std::endl;
937  }
938 
939  } // namespace o3prmr
940  } // namespace prm
941 } // namespace gum
INLINE void emplace(Args &&... args)
Definition: set_tpl.h:643
ParamScopeData(const std::string &s, const PRMReferenceSlot< GUM_SCALAR > &ref, Idx d)