29 #ifndef DOXYGEN_SHOULD_SKIP_THIS 43 template <
typename GUM_SCALAR >
45 const IBayesNet< GUM_SCALAR >* BN,
47 bool use_binary_join_tree) :
48 JointTargetedInference< GUM_SCALAR >(BN),
49 EvidenceInference< GUM_SCALAR >(BN),
50 __use_binary_join_tree(use_binary_join_tree) {
63 template <
typename GUM_SCALAR >
67 for (
const auto pot : pots.second)
101 template <
typename GUM_SCALAR >
103 const Triangulation& new_triangulation) {
112 template <
typename GUM_SCALAR >
120 template <
typename GUM_SCALAR >
129 template <
typename GUM_SCALAR >
131 Potential< GUM_SCALAR >* (*proj)(
const Potential< GUM_SCALAR >&,
132 const Set< const DiscreteVariable* >&)) {
138 template <
typename GUM_SCALAR >
140 Potential< GUM_SCALAR >* (*comb)(
const Potential< GUM_SCALAR >&,
141 const Potential< GUM_SCALAR >&)) {
147 template <
typename GUM_SCALAR >
151 potset.second.clear();
153 mess_computed.second =
false;
156 for (
const auto& potset : __created_potentials)
157 for (
const auto pot : potset.second)
161 for (
const auto& pot : __target_posteriors)
163 for (
const auto& pot : __joint_target_posteriors)
173 template <
typename GUM_SCALAR >
185 "setFindBarrenNodesType for type " 186 << (
unsigned int)type <<
" is not implemented yet");
198 template <
typename GUM_SCALAR >
201 bool isHardEvidence) {
210 }
catch (DuplicateElement&) {
222 template <
typename GUM_SCALAR >
225 bool isHardEvidence) {
233 }
catch (DuplicateElement&) {
249 template <
typename GUM_SCALAR >
251 bool has_hard_evidence) {
258 }
catch (DuplicateElement&) {
275 template <
typename GUM_SCALAR >
277 const NodeId id,
bool hasChangedSoftHard) {
278 if (hasChangedSoftHard)
283 }
catch (DuplicateElement&) {
293 template <
typename GUM_SCALAR >
300 template <
typename GUM_SCALAR >
306 template <
typename GUM_SCALAR >
313 template <
typename GUM_SCALAR >
319 template <
typename GUM_SCALAR >
324 template <
typename GUM_SCALAR >
328 template <
typename GUM_SCALAR >
330 const IBayesNet< GUM_SCALAR >* bn) {}
333 template <
typename GUM_SCALAR >
338 template <
typename GUM_SCALAR >
343 template <
typename GUM_SCALAR >
356 for (
const auto node : this->
targets()) {
357 if (!
__graph.
exists(node) && !hard_ev_nodes.exists(node))
return true;
362 bool containing_clique_found =
false;
363 for (
const auto node : joint_target) {
367 for (
const auto xnode : joint_target) {
368 if (!clique.contains(xnode) && !hard_ev_nodes.exists(xnode)) {
373 }
catch (NotFound&) { found =
false; }
376 containing_clique_found =
true;
381 if (!containing_clique_found)
return true;
387 if ((change.second == EvidenceChangeType::EVIDENCE_ADDED)
398 template <
typename GUM_SCALAR >
414 const auto& bn = this->
BN();
416 for (
auto node : bn.dag())
428 target_nodes += nodeset;
433 if (target_nodes.size() != bn.size()) {
434 BarrenNodesFinder finder(&(bn.dag()));
435 finder.setTargets(&target_nodes);
438 for (
const auto& pair : this->
evidence()) {
439 evidence_nodes.
insert(pair.first);
441 finder.setEvidence(&evidence_nodes);
443 NodeSet barren_nodes = finder.barrenNodes();
446 for (
const auto node : barren_nodes) {
453 for (
const auto node :
__graph) {
454 const NodeSet& parents = bn.parents(node);
455 for (
auto iter1 = parents.cbegin(); iter1 != parents.cend(); ++iter1) {
456 __graph.addEdge(*iter1, node);
458 for (++iter2; iter2 != parents.cend(); ++iter2) {
459 __graph.addEdge(*iter1, *iter2);
468 for (
auto iter1 = nodeset.cbegin(); iter1 != nodeset.cend(); ++iter1) {
470 for (++iter2; iter2 != nodeset.cend(); ++iter2) {
471 __graph.addEdge(*iter1, *iter2);
479 __graph.eraseNode(node);
492 BinaryJoinTreeConverterDefault bjt_converter;
494 __JT =
new CliqueGraph(
495 bjt_converter.convert(triang_jt, this->domainSizes(), emptyset));
497 __JT =
new CliqueGraph(triang_jt);
505 const std::vector< NodeId >& JT_elim_order =
507 NodeProperty< int > elim_order(
Size(JT_elim_order.size()));
508 for (std::size_t i = std::size_t(0), size = JT_elim_order.size(); i < size;
510 elim_order.insert(JT_elim_order[i], (
int)i);
511 const DAG& dag = bn.dag();
512 for (
const auto node : __graph) {
514 NodeId first_eliminated_node = node;
515 int elim_number = elim_order[first_eliminated_node];
517 for (
const auto parent : dag.parents(node)) {
518 if (__graph.existsNode(parent) && (elim_order[parent] < elim_number)) {
519 elim_number = elim_order[parent];
520 first_eliminated_node = parent;
535 for (
const auto node : __hard_ev_nodes) {
537 NodeSet pars(dag.parents(node).size());
538 for (
const auto par : dag.parents(node))
539 if (__graph.exists(par)) pars.
insert(par);
542 NodeId first_eliminated_node = *(pars.begin());
543 int elim_number = elim_order[first_eliminated_node];
545 for (
const auto parent : pars) {
546 if (elim_order[parent] < elim_number) {
547 elim_number = elim_order[parent];
548 first_eliminated_node = parent;
568 for (
const auto node : __hard_ev_nodes)
569 if (nodeset.contains(node)) nodeset.
erase(node);
571 if (!nodeset.empty()) {
574 NodeId first_eliminated_node = *(nodeset.begin());
575 int elim_number = elim_order[first_eliminated_node];
576 for (
const auto node : nodeset) {
577 if (elim_order[node] < elim_number) {
578 elim_number = elim_order[node];
579 first_eliminated_node = node;
593 for (
const auto& xpot : __clique_ss_potential) {
602 for (
const auto node : *
__JT) {
607 for (
auto& potlist : __created_potentials)
608 for (
auto pot : potlist.second)
610 __created_potentials.clear();
614 for (
auto pot_pair : __hard_ev_projected_CPTs)
615 delete pot_pair.second;
616 __hard_ev_projected_CPTs.clear();
624 __separator_potentials.clear();
625 __messages_computed.clear();
626 for (
const auto& edge : __JT->edges()) {
627 const Arc arc1(edge.first(), edge.second());
628 __separator_potentials.insert(arc1, empty_set);
629 __messages_computed.insert(arc1,
false);
630 const Arc arc2(Arc(edge.second(), edge.first()));
631 __separator_potentials.insert(arc2, empty_set);
632 __messages_computed.insert(arc2,
false);
636 for (
const auto& pot : __target_posteriors)
638 __target_posteriors.clear();
639 for (
const auto& pot : __joint_target_posteriors)
641 __joint_target_posteriors.clear();
650 for (
const auto node : dag) {
651 if (__graph.exists(node) || __hard_ev_nodes.contains(node)) {
652 const Potential< GUM_SCALAR >& cpt = bn.cpt(node);
656 const auto& variables = cpt.variablesSequence();
657 for (
const auto var : variables) {
658 NodeId xnode = bn.nodeId(*var);
659 if (__hard_ev_nodes.contains(xnode)) hard_nodes.insert(xnode);
665 if (hard_nodes.empty()) {
671 if (hard_nodes.size() == variables.size()) {
673 const auto& vars = cpt.variablesSequence();
674 for (
auto var : vars)
676 for (
Size i = 0; i < hard_nodes.size(); ++i) {
677 inst.chgVal(variables[i], hard_evidence[bn.nodeId(*(variables[i]))]);
682 Set< const DiscreteVariable* > hard_variables;
684 for (
const auto xnode : hard_nodes) {
685 marg_cpt_set.insert(
evidence[xnode]);
686 hard_variables.insert(&(bn.variable(xnode)));
690 MultiDimCombineAndProjectDefault< GUM_SCALAR, Potential >
693 combine_and_project.combineAndProject(marg_cpt_set, hard_variables);
696 if (new_cpt_list.size() != 1) {
698 for (
auto pot : new_cpt_list) {
699 if (!marg_cpt_set.contains(pot))
delete pot;
702 "the projection of a potential containing " 703 <<
"hard evidence is empty!");
705 const Potential< GUM_SCALAR >* projected_cpt = *(new_cpt_list.begin());
707 __hard_ev_projected_CPTs.insert(node, projected_cpt);
723 __clique_ss_potential.clear();
724 MultiDimCombinationDefault< GUM_SCALAR, Potential > fast_combination(
727 const auto& potset = xpotset.second;
728 if (potset.size() > 0) {
733 if (potset.size() == 1) {
734 __clique_ss_potential.insert(xpotset.first, *(potset.cbegin()));
736 auto joint = fast_combination.combine(potset);
737 __clique_ss_potential.insert(xpotset.first, joint);
743 __evidence_changes.clear();
749 template <
typename GUM_SCALAR >
764 template <
typename GUM_SCALAR >
768 invalidated_cliques.insert(to_id);
771 const Arc arc(from_id, to_id);
772 bool& message_computed = __messages_computed[arc];
773 if (message_computed) {
774 message_computed =
false;
775 __separator_potentials[arc].clear();
776 if (__created_potentials.exists(arc)) {
777 auto& arc_created_potentials = __created_potentials[arc];
778 for (
auto pot : arc_created_potentials)
780 arc_created_potentials.clear();
784 for (
const auto node_id : __JT->neighbours(to_id)) {
785 if (node_id != from_id)
794 template <
typename GUM_SCALAR >
800 NodeProperty< bool > ss_potential_to_deallocate(__clique_potentials.size());
801 for (
auto pot_iter = __clique_potentials.cbegin();
802 pot_iter != __clique_potentials.cend();
804 ss_potential_to_deallocate.insert(pot_iter.key(),
805 (pot_iter.val().size() > 1));
815 NodeSet hard_nodes_changed(__hard_ev_nodes.size());
816 for (
const auto node : __hard_ev_nodes)
817 if (__evidence_changes.exists(node)) hard_nodes_changed.
insert(node);
819 NodeSet nodes_with_projected_CPTs_changed;
820 const auto& bn = this->
BN();
821 for (
auto pot_iter = __hard_ev_projected_CPTs.beginSafe();
822 pot_iter != __hard_ev_projected_CPTs.endSafe();
824 for (
const auto var : bn.cpt(pot_iter.key()).variablesSequence()) {
825 if (hard_nodes_changed.contains(bn.nodeId(*var))) {
826 nodes_with_projected_CPTs_changed.insert(pot_iter.key());
827 delete pot_iter.val();
830 __hard_ev_projected_CPTs.erase(pot_iter);
844 NodeSet invalidated_cliques(__JT->size());
845 for (
const auto& pair : __evidence_changes) {
848 invalidated_cliques.
insert(clique);
849 for (
const auto neighbor : __JT->neighbours(clique)) {
857 for (
auto node : nodes_with_projected_CPTs_changed) {
859 invalidated_cliques.
insert(clique);
860 for (
const auto neighbor : __JT->neighbours(clique)) {
867 for (
const auto clique : invalidated_cliques) {
868 if (__clique_ss_potential.exists(clique)
869 && ss_potential_to_deallocate[clique]) {
870 delete __clique_ss_potential[clique];
879 for (
auto iter = __target_posteriors.beginSafe();
880 iter != __target_posteriors.endSafe();
882 if (__graph.exists(iter.key())
885 __target_posteriors.erase(iter);
890 for (
auto iter = __target_posteriors.beginSafe();
891 iter != __target_posteriors.endSafe();
893 if (hard_nodes_changed.contains(iter.key())) {
895 __target_posteriors.erase(iter);
900 for (
auto iter = __joint_target_posteriors.beginSafe();
901 iter != __joint_target_posteriors.endSafe();
905 __joint_target_posteriors.erase(iter);
913 __clique_potentials[
__node_to_clique[pot_pair.first]].erase(pot_pair.second);
915 __node_to_soft_evidence.clear();
919 __node_to_soft_evidence.insert(node,
evidence[node]);
930 for (
const auto node : nodes_with_projected_CPTs_changed) {
932 const Potential< GUM_SCALAR >& cpt = bn.cpt(node);
933 const auto& variables = cpt.variablesSequence();
934 Set< const DiscreteVariable* > hard_variables;
936 for (
const auto var : variables) {
937 NodeId xnode = bn.nodeId(*var);
938 if (__hard_ev_nodes.exists(xnode)) {
939 marg_cpt_set.insert(
evidence[xnode]);
940 hard_variables.insert(var);
945 MultiDimCombineAndProjectDefault< GUM_SCALAR, Potential >
948 combine_and_project.combineAndProject(marg_cpt_set, hard_variables);
951 if (new_cpt_list.size() != 1) {
953 for (
auto pot : new_cpt_list) {
954 if (!marg_cpt_set.contains(pot))
delete pot;
957 "the projection of a potential containing " 958 <<
"hard evidence is empty!");
960 const Potential< GUM_SCALAR >* projected_cpt = *(new_cpt_list.begin());
962 __hard_ev_projected_CPTs.insert(node, projected_cpt);
968 MultiDimCombinationDefault< GUM_SCALAR, Potential > fast_combination(
970 for (
const auto clique : invalidated_cliques) {
971 const auto& potset = __clique_potentials[clique];
973 if (potset.size() > 0) {
978 if (potset.size() == 1) {
979 __clique_ss_potential[clique] = *(potset.cbegin());
981 auto joint = fast_combination.combine(potset);
982 __clique_ss_potential[clique] = joint;
991 const Potential< GUM_SCALAR >& cpt = bn.cpt(node_cst.first);
992 const auto& variables = cpt.variablesSequence();
994 for (
const auto var : variables)
996 for (
const auto var : variables) {
997 inst.chgVal(var, hard_evidence[bn.nodeId(*var)]);
999 node_cst.second = cpt[inst];
1003 __evidence_changes.clear();
1008 template <
typename GUM_SCALAR >
1012 for (
const auto node : this->
targets()) {
1015 }
catch (Exception&) {}
1020 }
catch (Exception&) {}
1024 std::vector< std::pair< NodeId, Size > > possible_roots(clique_targets.size());
1025 const auto& bn = this->
BN();
1027 for (
const auto clique_id : clique_targets) {
1028 const auto& clique = __JT->clique(clique_id);
1030 for (
const auto node : clique) {
1031 dom_size *= bn.variable(node).domainSize();
1033 possible_roots[i] = std::pair< NodeId, Size >(clique_id, dom_size);
1038 std::sort(possible_roots.begin(),
1039 possible_roots.end(),
1040 [](
const std::pair< NodeId, Size >& a,
1041 const std::pair< NodeId, Size >& b) ->
bool {
1042 return a.second < b.second;
1046 NodeProperty< bool > marked = __JT->nodesProperty(
false);
1047 std::function< void(NodeId, NodeId) > diffuse_marks =
1048 [&marked, &diffuse_marks,
this](
NodeId node,
NodeId from) {
1049 if (!marked[node]) {
1050 marked[node] =
true;
1051 for (
const auto neigh : __JT->neighbours(node))
1052 if ((neigh != from) && !marked[neigh]) diffuse_marks(neigh, node);
1056 for (
const auto xclique : possible_roots) {
1057 NodeId clique = xclique.first;
1058 if (!marked[clique]) {
1060 diffuse_marks(clique, clique);
1067 template <
typename GUM_SCALAR >
1070 for (
const auto other : __JT->neighbours(
id)) {
1071 if ((other != from) && !__messages_computed[Arc(other,
id)])
1075 if ((
id != from) && !__messages_computed[Arc(
id, from)]) {
1082 template <
typename GUM_SCALAR >
1083 Set< const Potential< GUM_SCALAR >* >
1085 __PotentialSet& pot_list, Set< const DiscreteVariable* >& del_vars) {
1088 Set< const DiscreteVariable* > the_del_vars = del_vars;
1089 for (
auto iter = the_del_vars.beginSafe(); iter != the_del_vars.endSafe();
1091 NodeId id = this->
BN().nodeId(**iter);
1094 the_del_vars.erase(iter);
1099 HashTable< const DiscreteVariable*, __PotentialSet > var2pots;
1101 for (
const auto pot : pot_list) {
1102 const Sequence< const DiscreteVariable* >& vars = pot->variablesSequence();
1103 for (
const auto var : vars) {
1104 if (the_del_vars.exists(var)) {
1105 if (!var2pots.exists(var)) { var2pots.insert(var, empty_pot_set); }
1106 var2pots[var].insert(pot);
1113 HashTable< const Potential< GUM_SCALAR >*, Set< const DiscreteVariable* > >
1115 Set< const DiscreteVariable* > empty_var_set;
1116 for (
auto elt : var2pots) {
1117 if (elt.second.size() == 1) {
1118 const Potential< GUM_SCALAR >* pot = *(elt.second.begin());
1119 if (!pot2barren_var.exists(pot)) {
1120 pot2barren_var.insert(pot, empty_var_set);
1122 pot2barren_var[pot].insert(elt.first);
1131 for (
auto elt : pot2barren_var) {
1134 const Potential< GUM_SCALAR >* pot = elt.first;
1135 pot_list.erase(pot);
1139 if (pot->variablesSequence().size() != elt.second.size()) {
1140 auto new_pot = projector.project(*pot, elt.second);
1141 pot_list.insert(new_pot);
1142 projected_pots.insert(new_pot);
1146 return projected_pots;
1151 template <
typename GUM_SCALAR >
1152 Set< const Potential< GUM_SCALAR >* >
1154 Set<
const Potential< GUM_SCALAR >* > pot_list,
1155 Set< const DiscreteVariable* >& del_vars,
1156 Set< const DiscreteVariable* >& kept_vars) {
1166 MultiDimCombineAndProjectDefault< GUM_SCALAR, Potential > combine_and_project(
1169 combine_and_project.combineAndProject(pot_list, del_vars);
1174 for (
auto iter = barren_projected_potentials.beginSafe();
1175 iter != barren_projected_potentials.endSafe();
1177 if (!new_pot_list.exists(*iter))
delete *iter;
1181 for (
auto iter_pot = new_pot_list.beginSafe();
1182 iter_pot != new_pot_list.endSafe();
1184 if ((*iter_pot)->variablesSequence().size() == 0) {
1191 new_pot_list.erase(iter_pot);
1195 return new_pot_list;
1200 template <
typename GUM_SCALAR >
1205 if (__clique_ss_potential.exists(from_id))
1206 pot_list.insert(__clique_ss_potential[from_id]);
1209 for (
const auto other_id : __JT->neighbours(from_id))
1210 if (other_id != to_id)
1211 pot_list += __separator_potentials[Arc(other_id, from_id)];
1214 const NodeSet& from_clique = __JT->clique(from_id);
1215 const NodeSet& separator = __JT->separator(from_id, to_id);
1216 Set< const DiscreteVariable* > del_vars(from_clique.size());
1217 Set< const DiscreteVariable* > kept_vars(separator.size());
1218 const auto& bn = this->
BN();
1220 for (
const auto node : from_clique) {
1221 if (!separator.contains(node)) {
1222 del_vars.insert(&(bn.variable(node)));
1224 kept_vars.insert(&(bn.variable(node)));
1234 for (
auto iter = new_pot_list.beginSafe(); iter != new_pot_list.endSafe();
1236 const auto pot = *iter;
1237 if (pot->variablesSequence().size() == 1) {
1238 bool is_all_ones =
true;
1239 for (Instantiation inst(*pot); !inst.end(); ++inst) {
1241 is_all_ones =
false;
1246 if (!pot_list.exists(pot))
delete pot;
1247 new_pot_list.erase(iter);
1255 const Arc arc(from_id, to_id);
1256 if (!new_pot_list.empty()) {
1257 if (new_pot_list.size() == 1) {
1260 auto pot = *(new_pot_list.begin());
1261 __separator_potentials[arc] = std::move(new_pot_list);
1262 if (!pot_list.exists(pot)) {
1263 if (!__created_potentials.exists(arc))
1265 __created_potentials[arc].insert(pot);
1269 MultiDimCombinationDefault< GUM_SCALAR, Potential > fast_combination(
1271 auto joint = fast_combination.combine(new_pot_list);
1272 __separator_potentials[arc].insert(joint);
1273 if (!__created_potentials.exists(arc))
1275 __created_potentials[arc].insert(joint);
1278 for (
const auto pot : new_pot_list) {
1279 if (!pot_list.exists(pot)) {
delete pot; }
1284 __messages_computed[arc] =
true;
1289 template <
typename GUM_SCALAR >
1292 for (
const auto node : this->
targets()) {
1296 if (__graph.exists(node)) {
1311 template <
typename GUM_SCALAR >
1312 Potential< GUM_SCALAR >*
1314 const auto& bn = this->
BN();
1319 return new Potential< GUM_SCALAR >(*(this->
evidence()[id]));
1331 if (__clique_ss_potential.exists(clique_of_id))
1332 pot_list.insert(__clique_ss_potential[clique_of_id]);
1335 for (
const auto other : __JT->neighbours(clique_of_id))
1336 pot_list += __separator_potentials[Arc(other, clique_of_id)];
1339 const NodeSet& nodes = __JT->clique(clique_of_id);
1340 Set< const DiscreteVariable* > kept_vars{&(bn.variable(
id))};
1341 Set< const DiscreteVariable* > del_vars(nodes.size());
1342 for (
const auto node : nodes) {
1343 if (node !=
id) del_vars.
insert(&(bn.variable(node)));
1349 Potential< GUM_SCALAR >* joint =
nullptr;
1351 if (new_pot_list.size() == 1) {
1352 joint =
const_cast< Potential< GUM_SCALAR >*
>(*(new_pot_list.begin()));
1355 if (pot_list.exists(joint)) {
1356 joint =
new Potential< GUM_SCALAR >(*joint);
1360 new_pot_list.clear();
1363 MultiDimCombinationDefault< GUM_SCALAR, Potential > fast_combination(
1365 joint = fast_combination.combine(new_pot_list);
1369 for (
auto pot : new_pot_list)
1370 if (!pot_list.exists(pot))
delete pot;
1375 bool nonzero_found =
false;
1376 for (Instantiation inst(*joint); !inst.end(); ++inst) {
1377 if ((*joint)[inst]) {
1378 nonzero_found =
true;
1382 if (!nonzero_found) {
1386 "some evidence entered into the Bayes " 1387 "net are incompatible (their joint proba = 0)");
1394 template <
typename GUM_SCALAR >
1395 const Potential< GUM_SCALAR >&
1398 if (__target_posteriors.exists(
id)) {
return *(__target_posteriors[id]); }
1403 __target_posteriors.insert(
id, joint);
1410 template <
typename GUM_SCALAR >
1411 Potential< GUM_SCALAR >*
1418 if (targets.contains(node)) {
1419 targets.
erase(node);
1420 hard_ev_nodes.insert(node);
1427 if (targets.empty()) {
1429 for (
const auto node :
set) {
1432 if (pot_list.size() == 1) {
1433 auto pot =
new Potential< GUM_SCALAR >(**(pot_list.begin()));
1436 MultiDimCombinationDefault< GUM_SCALAR, Potential > fast_combination(
1438 return fast_combination.combine(pot_list);
1449 }
catch (NotFound&) {
1455 for (
const auto node : targets) {
1456 if (!__graph.exists(node)) {
1457 GUM_ERROR(UndefinedElement, node <<
" is not a target node");
1463 const std::vector< NodeId >& JT_elim_order =
1465 NodeProperty< int > elim_order(
Size(JT_elim_order.size()));
1466 for (std::size_t i = std::size_t(0), size = JT_elim_order.size(); i < size;
1468 elim_order.insert(JT_elim_order[i], (
int)i);
1469 NodeId first_eliminated_node = *(targets.begin());
1470 int elim_number = elim_order[first_eliminated_node];
1471 for (
const auto node : targets) {
1472 if (elim_order[node] < elim_number) {
1473 elim_number = elim_order[node];
1474 first_eliminated_node = node;
1481 const NodeSet& clique_nodes = __JT->clique(clique_of_set);
1482 for (
const auto node : targets) {
1483 if (!clique_nodes.contains(node)) {
1484 GUM_ERROR(UndefinedElement,
set <<
" is not a joint target");
1499 if (__clique_ss_potential.exists(clique_of_set))
1500 pot_list.insert(__clique_ss_potential[clique_of_set]);
1503 for (
const auto other : __JT->neighbours(clique_of_set))
1504 pot_list += __separator_potentials[Arc(other, clique_of_set)];
1507 const NodeSet& nodes = __JT->clique(clique_of_set);
1508 Set< const DiscreteVariable* > del_vars(nodes.size());
1509 Set< const DiscreteVariable* > kept_vars(targets.size());
1510 const auto& bn = this->
BN();
1511 for (
const auto node : nodes) {
1512 if (!targets.contains(node)) {
1513 del_vars.insert(&(bn.variable(node)));
1515 kept_vars.insert(&(bn.variable(node)));
1522 Potential< GUM_SCALAR >* joint =
nullptr;
1524 if ((new_pot_list.size() == 1) && hard_ev_nodes.empty()) {
1525 joint =
const_cast< Potential< GUM_SCALAR >*
>(*(new_pot_list.begin()));
1528 if (pot_list.exists(joint)) {
1529 joint =
new Potential< GUM_SCALAR >(*joint);
1533 new_pot_list.clear();
1539 for (
const auto node : hard_ev_nodes) {
1540 new_new_pot_list.insert(
evidence[node]);
1542 MultiDimCombinationDefault< GUM_SCALAR, Potential > fast_combination(
1544 joint = fast_combination.combine(new_new_pot_list);
1548 for (
auto pot : new_pot_list)
1549 if (!pot_list.exists(pot))
delete pot;
1553 bool nonzero_found =
false;
1554 for (Instantiation inst(*joint); !inst.end(); ++inst) {
1555 if ((*joint)[inst]) {
1556 nonzero_found =
true;
1560 if (!nonzero_found) {
1564 "some evidence entered into the Bayes " 1565 "net are incompatible (their joint proba = 0)");
1573 template <
typename GUM_SCALAR >
1574 const Potential< GUM_SCALAR >&
1577 if (__joint_target_posteriors.exists(
set)) {
1578 return *(__joint_target_posteriors[
set]);
1584 __joint_target_posteriors.insert(
set, joint);
1591 template <
typename GUM_SCALAR >
1592 const Potential< GUM_SCALAR >&
1596 if (__joint_target_posteriors.exists(wanted_target))
1597 return *(__joint_target_posteriors[wanted_target]);
1603 if (!__joint_target_posteriors.exists(declared_target)) {
1608 const auto& bn = this->
BN();
1609 Set< const DiscreteVariable* > del_vars;
1610 for (
const auto node : declared_target)
1611 if (!wanted_target.contains(node)) del_vars.insert(&(bn.variable(node)));
1612 auto pot =
new Potential< GUM_SCALAR >(
1613 __joint_target_posteriors[declared_target]->margSumOut(del_vars));
1616 __joint_target_posteriors.insert(wanted_target, pot);
1622 template <
typename GUM_SCALAR >
1633 GUM_SCALAR prob_ev = 1;
1634 for (
const auto root :
__roots) {
1636 NodeId node = *(__JT->clique(root).begin());
1639 for (Instantiation iter(*tmp); !iter.end(); ++iter)
1640 sum += tmp->get(iter);
1645 for (
const auto& projected_cpt : __constants)
1646 prob_ev *= projected_cpt.second;
1654 #endif // DOXYGEN_SHOULD_SKIP_THIS NodeProperty< const Potential< GUM_SCALAR > *> __clique_ss_potential
the potentials stored into the cliques by Shafer-Shenoy
void _onAllJointTargetsErased() final
fired before a all the joint targets are removed
void _onEvidenceChanged(const NodeId id, bool hasChangedSoftHard) final
fired after an evidence is changed, in particular when its status (soft/hard) changes ...
FindBarrenNodesType __barren_nodes_type
the type of barren nodes computation we wish
void __setProjectionFunction(Potential< GUM_SCALAR > *(*proj)(const Potential< GUM_SCALAR > &, const Set< const DiscreteVariable * > &))
sets the operator for performing the projections
virtual void clear()
removes all the nodes and edges from the graph
virtual void addNodeWithId(const NodeId id)
try to insert a node with the given id
const NodeProperty< const Potential< GUM_SCALAR > *> & evidence() const
returns the set of evidence
NodeProperty< GUM_SCALAR > __constants
the constants resulting from the projections of CPTs defined over only hard evidence nodes remove th...
JunctionTree * __junctionTree
the junction tree to answer the last inference query
void __collectMessage(NodeId id, NodeId from)
actually perform the collect phase
void _onEvidenceErased(const NodeId id, bool isHardEvidence) final
fired before an evidence is removed
__PotentialSet __marginalizeOut(__PotentialSet pot_list, Set< const DiscreteVariable * > &del_vars, Set< const DiscreteVariable * > &kept_vars)
removes variables del_vars from a list of potentials and returns the resulting list ...
void _onMarginalTargetErased(const NodeId id) final
fired before a single target is removed
void _onMarginalTargetAdded(const NodeId id) final
fired after a new single target is inserted
Set< NodeId > NodeSet
Some typdefs and define for shortcuts ...
ArcProperty< bool > __messages_computed
indicates whether a message (from one clique to another) has been computed
Set< const Potential< GUM_SCALAR > *> __PotentialSet
void __setCombinationFunction(Potential< GUM_SCALAR > *(*comb)(const Potential< GUM_SCALAR > &, const Potential< GUM_SCALAR > &))
sets the operator for performing the combinations
Copyright 2005-2019 Pierre-Henri WUILLEMIN et Christophe GONZALES (LIP6) {prenom.nom}_at_lip6.fr.
const GUM_SCALAR __1_minus_epsilon
for comparisons with 1 - epsilon
Copyright 2005-2019 Pierre-Henri WUILLEMIN et Christophe GONZALES (LIP6) {prenom.nom}_at_lip6.fr.
Potential< GUM_SCALAR > * _unnormalizedJointPosterior(NodeId id) final
returns a fresh potential equal to P(argument,evidence)
void _onAllMarginalTargetsAdded() final
fired after all the nodes of the BN are added as single targets
NodeSet __roots
a clique node used as a root in each connected component of __JT
Copyright 2005-2019 Pierre-Henri WUILLEMIN et Christophe GONZALES (LIP6) {prenom.nom}_at_lip6.fr.
bool exists(const Key &key) const
Checks whether there exists an element with a given key in the hashtable.
void erase(const Key &k)
Erases an element from the set.
bool exists(const NodeId id) const
alias for existsNode
GUM_SCALAR evidenceProbability()
returns the probability of evidence
void _makeInference() final
called when the inference has to be performed effectively
Copyright 2005-2019 Pierre-Henri WUILLEMIN et Christophe GONZALES (LIP6) {prenom.nom}_at_lip6.fr.
HashTable< NodeSet, const Potential< GUM_SCALAR > *> __joint_target_posteriors
the set of set target posteriors computed during the last inference
void _updateOutdatedBNPotentials() final
prepares inference when the latter is in OutdatedBNPotentials state
UndiGraph __graph
the undigraph extracted from the BN and used to construct the join tree
virtual void eraseNode(const NodeId id)
remove a node and its adjacent edges from the graph
virtual void makeInference() final
perform the heavy computations needed to compute the targets' posteriors
ShaferShenoyInference(const IBayesNet< GUM_SCALAR > *BN, FindBarrenNodesType barren_type=FindBarrenNodesType::FIND_BARREN_NODES, bool use_binary_join_tree=true)
default constructor
NodeProperty< EvidenceChangeType > __evidence_changes
indicates which nodes of the BN have evidence that changed since the last inference ...
FindBarrenNodesType
type of algorithm to determine barren nodes
NodeProperty< const Potential< GUM_SCALAR > *> __node_to_soft_evidence
the soft evidence stored in the cliques per their assigned node in the BN
const NodeSet & softEvidenceNodes() const
returns the set of nodes with soft evidence
HashTable< NodeId, NodeId > __node_to_clique
for each node of __graph (~ in the Bayes net), associate an ID in the JT
virtual const NodeProperty< Size > & domainSizes() const final
get the domain sizes of the random variables of the BN
NodeProperty< const Potential< GUM_SCALAR > *> __target_posteriors
the set of single posteriors computed during the last inference
NodeProperty< const Potential< GUM_SCALAR > *> __hard_ev_projected_CPTs
the CPTs that were projected due to hard evidence nodes
virtual NodeId createdJunctionTreeClique(const NodeId id)=0
returns the Id of the clique created by the elimination of a given node during the triangulation proc...
void _onJointTargetAdded(const NodeSet &set) final
fired after a new joint target is inserted
Copyright 2005-2019 Pierre-Henri WUILLEMIN et Christophe GONZALES (LIP6) {prenom.nom}_at_lip6.fr.
bool __isNewJTNeeded() const
check whether a new join tree is really needed for the next inference
void _updateOutdatedBNStructure() final
prepares inference when the latter is in OutdatedBNStructure state
const NodeProperty< Idx > & hardEvidence() const
indicate for each node with hard evidence which value it took
virtual bool isDone() const noexcept final
returns whether the inference object is in a done state
void setTriangulation(const Triangulation &new_triangulation)
use a new triangulation algorithm
CliqueGraph JoinTree
a join tree is a clique graph satisfying the running intersection property (but some cliques may be i...
const Potential< GUM_SCALAR > & _jointPosterior(const NodeSet &set) final
returns the posterior of a declared target set
virtual const Set< NodeSet > & jointTargets() const noexcept final
returns the list of joint targets
void _onAllTargetsErased() final
fired before a all single and joint_targets are removed
virtual void _onBayesNetChanged(const IBayesNet< GUM_SCALAR > *bn) final
fired after a new Bayes net has been assigned to the engine
bool __is_new_jt_needed
indicates whether a new join tree is needed for the next inference
static INLINE Potential< GUM_SCALAR > * SSNewprojPotential(const Potential< GUM_SCALAR > &t1, const Set< const DiscreteVariable * > &del_vars)
Copyright 2005-2019 Pierre-Henri WUILLEMIN et Christophe GONZALES (LIP6) {prenom.nom}_at_lip6.fr.
void setFindBarrenNodesType(FindBarrenNodesType type)
sets how we determine barren nodes
void _onJointTargetErased(const NodeSet &set) final
fired before a joint target is removed
HashTable< NodeSet, NodeId > __joint_target_to_clique
for each set target, assign a clique in the JT that contains it
void _setOutdatedBNStructureState()
put the inference into an outdated BN structure state
~ShaferShenoyInference() final
destructor
const NodeSet & clique(const NodeId idClique) const
returns the set of nodes included into a given clique
ArcProperty< __PotentialSet > __created_potentials
the set of potentials created for the last inference messages
const NodeSet & hardEvidenceNodes() const
returns the set of nodes with hard evidence
Triangulation * __triangulation
the triangulation class creating the junction tree used for inference
void __createNewJT()
create a new junction tree as well as its related data structures
CliqueGraph JunctionTree
a junction tree is a clique graph satisfying the running intersection property and such that no cliqu...
Copyright 2005-2019 Pierre-Henri WUILLEMIN et Christophe GONZALES (LIP6) {prenom.nom}_at_lip6.fr.
void clear()
Removes all the elements in the hash table.
void __diffuseMessageInvalidations(NodeId from, NodeId to, NodeSet &cliques_invalidated)
invalidate all the messages sent from a given clique
Copyright 2005-2019 Pierre-Henri WUILLEMIN et Christophe GONZALES (LIP6) {prenom.nom}_at_lip6.fr.
__PotentialSet __removeBarrenVariables(__PotentialSet &pot_list, Set< const DiscreteVariable * > &del_vars)
JoinTree * __JT
the join (or junction) tree used to answer the last inference query
void __computeJoinTreeRoots()
compute a root for each connected component of __JT
virtual bool isInferenceReady() const noexcept final
returns whether the inference object is in a ready state
void _onEvidenceAdded(const NodeId id, bool isHardEvidence) final
fired after a new evidence is inserted
virtual Triangulation * newFactory() const =0
returns a fresh triangulation of the same type as the current object but with an empty graph ...
virtual const NodeSet & targets() const noexcept final
returns the list of marginal targets
NodeSet __hard_ev_nodes
the hard evidence nodes which were projected in CPTs
bool __use_binary_join_tree
indicates whether we should transform junction trees into binary join trees
void clear()
Removes all the elements, if any, from the set.
std::size_t Size
In aGrUM, hashed values are unsigned long int.
NodeProperty< __PotentialSet > __clique_potentials
the list of all potentials stored in the cliques
Potential< GUM_SCALAR > *(* __combination_op)(const Potential< GUM_SCALAR > &, const Potential< GUM_SCALAR > &)
the operator for performing the combinations
const Potential< GUM_SCALAR > & _posterior(NodeId id) final
returns the posterior of a given variable
virtual const CliqueGraph & junctionTree()=0
returns a compatible junction tree
value_type & insert(const Key &key, const Val &val)
Adds a new element (actually a copy of this element) into the hash table.
const JoinTree * joinTree()
returns the current join tree used
virtual const std::vector< NodeId > & eliminationOrder()=0
returns an elimination ordering compatible with the triangulated graph
void _onAllEvidenceErased(bool contains_hard_evidence) final
fired before all the evidence are erased
void __invalidateAllMessages()
invalidate all messages, posteriors and created potentials
virtual const IBayesNet< GUM_SCALAR > & BN() const final
Returns a constant reference over the IBayesNet referenced by this class.
void _onAllMarginalTargetsErased() final
fired before a all the single targets are removed
Potential< GUM_SCALAR > *(* __projection_op)(const Potential< GUM_SCALAR > &, const Set< const DiscreteVariable * > &)
the operator for performing the projections
Size NodeId
Type for node ids.
void insert(const Key &k)
Inserts a new element into the set.
void __produceMessage(NodeId from_id, NodeId to_id)
creates the message sent by clique from_id to clique to_id
#define GUM_ERROR(type, msg)
virtual void setGraph(const UndiGraph *graph, const NodeProperty< Size > *domsizes)=0
initialize the triangulation data structures for a new graph
void _setOutdatedBNPotentialsState()
puts the inference into an OutdatedBNPotentials state if it is not already in an OutdatedBNStructure ...
ArcProperty< __PotentialSet > __separator_potentials
the list of all potentials stored in the separators after inferences
Copyright 2005-2019 Pierre-Henri WUILLEMIN et Christophe GONZALES (LIP6) {prenom.nom}_at_lip6.fr.
const JunctionTree * junctionTree()
returns the current junction tree