26 #ifndef DOXYGEN_SHOULD_SKIP_THIS 40 template <
typename GUM_SCALAR >
42 const IBayesNet< GUM_SCALAR >* BN,
44 bool use_binary_join_tree) :
45 JointTargetedInference< GUM_SCALAR >(BN),
46 EvidenceInference< GUM_SCALAR >(BN),
47 __use_binary_join_tree(use_binary_join_tree) {
60 template <
typename GUM_SCALAR >
64 for (
const auto pot : pots.second)
98 template <
typename GUM_SCALAR >
100 const Triangulation& new_triangulation) {
109 template <
typename GUM_SCALAR >
117 template <
typename GUM_SCALAR >
126 template <
typename GUM_SCALAR >
128 Potential< GUM_SCALAR >* (*proj)(
const Potential< GUM_SCALAR >&,
129 const Set< const DiscreteVariable* >&)) {
135 template <
typename GUM_SCALAR >
137 Potential< GUM_SCALAR >* (*comb)(
const Potential< GUM_SCALAR >&,
138 const Potential< GUM_SCALAR >&)) {
144 template <
typename GUM_SCALAR >
148 potset.second.clear();
150 mess_computed.second =
false;
153 for (
const auto& potset : __created_potentials)
154 for (
const auto pot : potset.second)
158 for (
const auto& pot : __target_posteriors)
160 for (
const auto& pot : __joint_target_posteriors)
170 template <
typename GUM_SCALAR >
182 "setFindBarrenNodesType for type " 183 << (
unsigned int)type <<
" is not implemented yet");
195 template <
typename GUM_SCALAR >
198 bool isHardEvidence) {
207 }
catch (DuplicateElement&) {
219 template <
typename GUM_SCALAR >
222 bool isHardEvidence) {
230 }
catch (DuplicateElement&) {
246 template <
typename GUM_SCALAR >
248 bool has_hard_evidence) {
255 }
catch (DuplicateElement&) {
272 template <
typename GUM_SCALAR >
274 const NodeId id,
bool hasChangedSoftHard) {
275 if (hasChangedSoftHard)
280 }
catch (DuplicateElement&) {
290 template <
typename GUM_SCALAR >
297 template <
typename GUM_SCALAR >
303 template <
typename GUM_SCALAR >
310 template <
typename GUM_SCALAR >
316 template <
typename GUM_SCALAR >
321 template <
typename GUM_SCALAR >
325 template <
typename GUM_SCALAR >
327 const IBayesNet< GUM_SCALAR >* bn) {}
330 template <
typename GUM_SCALAR >
335 template <
typename GUM_SCALAR >
340 template <
typename GUM_SCALAR >
353 for (
const auto node : this->
targets()) {
354 if (!
__graph.
exists(node) && !hard_ev_nodes.exists(node))
return true;
358 bool containing_clique_found =
false;
359 for (
const auto node : nodes) {
362 for (
const auto xnode : nodes) {
363 if (!clique.contains(xnode) && !hard_ev_nodes.exists(xnode)) {
369 containing_clique_found =
true;
374 if (!containing_clique_found)
return true;
380 if ((change.second == EvidenceChangeType::EVIDENCE_ADDED)
391 template <
typename GUM_SCALAR >
407 const auto& bn = this->
BN();
409 for (
auto node : bn.dag())
421 target_nodes += nodeset;
426 if (target_nodes.size() != bn.size()) {
427 BarrenNodesFinder finder(&(bn.dag()));
428 finder.setTargets(&target_nodes);
431 for (
const auto& pair : this->
evidence()) {
432 evidence_nodes.
insert(pair.first);
434 finder.setEvidence(&evidence_nodes);
436 NodeSet barren_nodes = finder.barrenNodes();
439 for (
const auto node : barren_nodes) {
446 for (
const auto node :
__graph) {
447 const NodeSet& parents = bn.parents(node);
448 for (
auto iter1 = parents.cbegin(); iter1 != parents.cend(); ++iter1) {
449 __graph.addEdge(*iter1, node);
451 for (++iter2; iter2 != parents.cend(); ++iter2) {
452 __graph.addEdge(*iter1, *iter2);
461 for (
auto iter1 = nodeset.cbegin(); iter1 != nodeset.cend(); ++iter1) {
463 for (++iter2; iter2 != nodeset.cend(); ++iter2) {
464 __graph.addEdge(*iter1, *iter2);
472 __graph.eraseNode(node);
485 BinaryJoinTreeConverterDefault bjt_converter;
487 __JT =
new CliqueGraph(
488 bjt_converter.convert(triang_jt, this->domainSizes(), emptyset));
490 __JT =
new CliqueGraph(triang_jt);
498 const std::vector< NodeId >& JT_elim_order =
500 NodeProperty< int > elim_order(
Size(JT_elim_order.size()));
501 for (std::size_t i = std::size_t(0), size = JT_elim_order.size(); i < size;
503 elim_order.insert(JT_elim_order[i], (
int)i);
504 const DAG& dag = bn.dag();
505 for (
const auto node : __graph) {
507 NodeId first_eliminated_node = node;
508 int elim_number = elim_order[first_eliminated_node];
510 for (
const auto parent : dag.parents(node)) {
511 if (__graph.existsNode(parent) && (elim_order[parent] < elim_number)) {
512 elim_number = elim_order[parent];
513 first_eliminated_node = parent;
528 for (
const auto node : __hard_ev_nodes) {
530 NodeSet pars(dag.parents(node).size());
531 for (
const auto par : dag.parents(node))
532 if (__graph.exists(par)) pars.
insert(par);
535 NodeId first_eliminated_node = *(pars.begin());
536 int elim_number = elim_order[first_eliminated_node];
538 for (
const auto parent : pars) {
539 if (elim_order[parent] < elim_number) {
540 elim_number = elim_order[parent];
541 first_eliminated_node = parent;
561 for (
const auto node : __hard_ev_nodes)
562 if (nodeset.contains(node)) nodeset.
erase(node);
564 if (!nodeset.empty()) {
567 NodeId first_eliminated_node = *(nodeset.begin());
568 int elim_number = elim_order[first_eliminated_node];
569 for (
const auto node : nodeset) {
570 if (elim_order[node] < elim_number) {
571 elim_number = elim_order[node];
572 first_eliminated_node = node;
586 for (
const auto& xpot : __clique_ss_potential) {
595 for (
const auto node : *
__JT) {
600 for (
auto& potlist : __created_potentials)
601 for (
auto pot : potlist.second)
603 __created_potentials.clear();
607 for (
auto pot_pair : __hard_ev_projected_CPTs)
608 delete pot_pair.second;
609 __hard_ev_projected_CPTs.clear();
617 __separator_potentials.clear();
618 __messages_computed.clear();
619 for (
const auto& edge : __JT->edges()) {
620 const Arc arc1(edge.first(), edge.second());
621 __separator_potentials.insert(arc1, empty_set);
622 __messages_computed.insert(arc1,
false);
623 const Arc arc2(Arc(edge.second(), edge.first()));
624 __separator_potentials.insert(arc2, empty_set);
625 __messages_computed.insert(arc2,
false);
629 for (
const auto& pot : __target_posteriors)
631 __target_posteriors.clear();
632 for (
const auto& pot : __joint_target_posteriors)
634 __joint_target_posteriors.clear();
643 for (
const auto node : dag) {
644 if (__graph.exists(node) || __hard_ev_nodes.contains(node)) {
645 const Potential< GUM_SCALAR >& cpt = bn.cpt(node);
649 const auto& variables = cpt.variablesSequence();
650 for (
const auto var : variables) {
651 NodeId xnode = bn.nodeId(*var);
652 if (__hard_ev_nodes.contains(xnode)) hard_nodes.insert(xnode);
658 if (hard_nodes.empty()) {
664 if (hard_nodes.size() == variables.size()) {
666 const auto& vars = cpt.variablesSequence();
667 for (
auto var : vars)
669 for (
Size i = 0; i < hard_nodes.size(); ++i) {
670 inst.chgVal(variables[i], hard_evidence[bn.nodeId(*(variables[i]))]);
675 Set< const DiscreteVariable* > hard_variables;
677 for (
const auto xnode : hard_nodes) {
678 marg_cpt_set.insert(
evidence[xnode]);
679 hard_variables.insert(&(bn.variable(xnode)));
683 MultiDimCombineAndProjectDefault< GUM_SCALAR, Potential >
686 combine_and_project.combineAndProject(marg_cpt_set, hard_variables);
689 if (new_cpt_list.size() != 1) {
691 for (
auto pot : new_cpt_list) {
692 if (!marg_cpt_set.contains(pot))
delete pot;
695 "the projection of a potential containing " 696 <<
"hard evidence is empty!");
698 const Potential< GUM_SCALAR >* projected_cpt = *(new_cpt_list.begin());
700 __hard_ev_projected_CPTs.insert(node, projected_cpt);
716 __clique_ss_potential.clear();
717 MultiDimCombinationDefault< GUM_SCALAR, Potential > fast_combination(
720 const auto& potset = xpotset.second;
721 if (potset.size() > 0) {
726 if (potset.size() == 1) {
727 __clique_ss_potential.insert(xpotset.first, *(potset.cbegin()));
729 auto joint = fast_combination.combine(potset);
730 __clique_ss_potential.insert(xpotset.first, joint);
736 __evidence_changes.clear();
742 template <
typename GUM_SCALAR >
757 template <
typename GUM_SCALAR >
761 invalidated_cliques.insert(to_id);
764 const Arc arc(from_id, to_id);
765 bool& message_computed = __messages_computed[arc];
766 if (message_computed) {
767 message_computed =
false;
768 __separator_potentials[arc].clear();
769 if (__created_potentials.exists(arc)) {
770 auto& arc_created_potentials = __created_potentials[arc];
771 for (
auto pot : arc_created_potentials)
773 arc_created_potentials.clear();
777 for (
const auto node_id : __JT->neighbours(to_id)) {
778 if (node_id != from_id)
787 template <
typename GUM_SCALAR >
793 NodeProperty< bool > ss_potential_to_deallocate(__clique_potentials.size());
794 for (
auto pot_iter = __clique_potentials.cbegin();
795 pot_iter != __clique_potentials.cend();
797 ss_potential_to_deallocate.insert(pot_iter.key(),
798 (pot_iter.val().size() > 1));
808 NodeSet hard_nodes_changed(__hard_ev_nodes.size());
809 for (
const auto node : __hard_ev_nodes)
810 if (__evidence_changes.exists(node)) hard_nodes_changed.
insert(node);
812 NodeSet nodes_with_projected_CPTs_changed;
813 const auto& bn = this->
BN();
814 for (
auto pot_iter = __hard_ev_projected_CPTs.beginSafe();
815 pot_iter != __hard_ev_projected_CPTs.endSafe();
817 for (
const auto var : bn.cpt(pot_iter.key()).variablesSequence()) {
818 if (hard_nodes_changed.contains(bn.nodeId(*var))) {
819 nodes_with_projected_CPTs_changed.insert(pot_iter.key());
820 delete pot_iter.val();
823 __hard_ev_projected_CPTs.erase(pot_iter);
837 NodeSet invalidated_cliques(__JT->size());
838 for (
const auto& pair : __evidence_changes) {
841 invalidated_cliques.
insert(clique);
842 for (
const auto neighbor : __JT->neighbours(clique)) {
850 for (
auto node : nodes_with_projected_CPTs_changed) {
852 invalidated_cliques.
insert(clique);
853 for (
const auto neighbor : __JT->neighbours(clique)) {
860 for (
const auto clique : invalidated_cliques) {
861 if (__clique_ss_potential.exists(clique)
862 && ss_potential_to_deallocate[clique]) {
863 delete __clique_ss_potential[clique];
872 for (
auto iter = __target_posteriors.beginSafe();
873 iter != __target_posteriors.endSafe();
875 if (__graph.exists(iter.key())
878 __target_posteriors.erase(iter);
883 for (
auto iter = __target_posteriors.beginSafe();
884 iter != __target_posteriors.endSafe();
886 if (hard_nodes_changed.contains(iter.key())) {
888 __target_posteriors.erase(iter);
893 for (
auto iter = __joint_target_posteriors.beginSafe();
894 iter != __joint_target_posteriors.endSafe();
898 __joint_target_posteriors.erase(iter);
906 __clique_potentials[
__node_to_clique[pot_pair.first]].erase(pot_pair.second);
908 __node_to_soft_evidence.clear();
912 __node_to_soft_evidence.insert(node,
evidence[node]);
923 for (
const auto node : nodes_with_projected_CPTs_changed) {
925 const Potential< GUM_SCALAR >& cpt = bn.cpt(node);
926 const auto& variables = cpt.variablesSequence();
927 Set< const DiscreteVariable* > hard_variables;
929 for (
const auto var : variables) {
930 NodeId xnode = bn.nodeId(*var);
931 if (__hard_ev_nodes.exists(xnode)) {
932 marg_cpt_set.insert(
evidence[xnode]);
933 hard_variables.insert(var);
938 MultiDimCombineAndProjectDefault< GUM_SCALAR, Potential >
941 combine_and_project.combineAndProject(marg_cpt_set, hard_variables);
944 if (new_cpt_list.size() != 1) {
946 for (
auto pot : new_cpt_list) {
947 if (!marg_cpt_set.contains(pot))
delete pot;
950 "the projection of a potential containing " 951 <<
"hard evidence is empty!");
953 const Potential< GUM_SCALAR >* projected_cpt = *(new_cpt_list.begin());
955 __hard_ev_projected_CPTs.insert(node, projected_cpt);
961 MultiDimCombinationDefault< GUM_SCALAR, Potential > fast_combination(
963 for (
const auto clique : invalidated_cliques) {
964 const auto& potset = __clique_potentials[clique];
966 if (potset.size() > 0) {
971 if (potset.size() == 1) {
972 __clique_ss_potential[clique] = *(potset.cbegin());
974 auto joint = fast_combination.combine(potset);
975 __clique_ss_potential[clique] = joint;
984 const Potential< GUM_SCALAR >& cpt = bn.cpt(node_cst.first);
985 const auto& variables = cpt.variablesSequence();
987 for (
const auto var : variables)
989 for (
const auto var : variables) {
990 inst.chgVal(var, hard_evidence[bn.nodeId(*var)]);
992 node_cst.second = cpt[inst];
996 __evidence_changes.clear();
1001 template <
typename GUM_SCALAR >
1005 for (
const auto node : this->
targets()) {
1008 }
catch (Exception&) {}
1013 }
catch (Exception&) {}
1017 std::vector< std::pair< NodeId, Size > > possible_roots(clique_targets.size());
1018 const auto& bn = this->
BN();
1020 for (
const auto clique_id : clique_targets) {
1021 const auto& clique = __JT->clique(clique_id);
1023 for (
const auto node : clique) {
1024 dom_size *= bn.variable(node).domainSize();
1026 possible_roots[i] = std::pair< NodeId, Size >(clique_id, dom_size);
1031 std::sort(possible_roots.begin(),
1032 possible_roots.end(),
1033 [](
const std::pair< NodeId, Size >& a,
1034 const std::pair< NodeId, Size >& b) ->
bool {
1035 return a.second < b.second;
1039 NodeProperty< bool > marked = __JT->nodesProperty(
false);
1040 std::function< void(NodeId, NodeId) > diffuse_marks =
1041 [&marked, &diffuse_marks,
this](
NodeId node,
NodeId from) {
1042 if (!marked[node]) {
1043 marked[node] =
true;
1044 for (
const auto neigh : __JT->neighbours(node))
1045 if ((neigh != from) && !marked[neigh]) diffuse_marks(neigh, node);
1049 for (
const auto xclique : possible_roots) {
1050 NodeId clique = xclique.first;
1051 if (!marked[clique]) {
1053 diffuse_marks(clique, clique);
1060 template <
typename GUM_SCALAR >
1063 for (
const auto other : __JT->neighbours(
id)) {
1064 if ((other != from) && !__messages_computed[Arc(other,
id)])
1068 if ((
id != from) && !__messages_computed[Arc(
id, from)]) {
1075 template <
typename GUM_SCALAR >
1076 Set< const Potential< GUM_SCALAR >* >
1078 __PotentialSet& pot_list, Set< const DiscreteVariable* >& del_vars) {
1081 Set< const DiscreteVariable* > the_del_vars = del_vars;
1082 for (
auto iter = the_del_vars.beginSafe(); iter != the_del_vars.endSafe();
1084 NodeId id = this->
BN().nodeId(**iter);
1087 the_del_vars.erase(iter);
1092 HashTable< const DiscreteVariable*, __PotentialSet > var2pots;
1094 for (
const auto pot : pot_list) {
1095 const Sequence< const DiscreteVariable* >& vars = pot->variablesSequence();
1096 for (
const auto var : vars) {
1097 if (the_del_vars.exists(var)) {
1098 if (!var2pots.exists(var)) { var2pots.insert(var, empty_pot_set); }
1099 var2pots[var].insert(pot);
1106 HashTable< const Potential< GUM_SCALAR >*, Set< const DiscreteVariable* > >
1108 Set< const DiscreteVariable* > empty_var_set;
1109 for (
auto elt : var2pots) {
1110 if (elt.second.size() == 1) {
1111 const Potential< GUM_SCALAR >* pot = *(elt.second.begin());
1112 if (!pot2barren_var.exists(pot)) {
1113 pot2barren_var.insert(pot, empty_var_set);
1115 pot2barren_var[pot].insert(elt.first);
1124 for (
auto elt : pot2barren_var) {
1127 const Potential< GUM_SCALAR >* pot = elt.first;
1128 pot_list.erase(pot);
1132 if (pot->variablesSequence().size() != elt.second.size()) {
1133 auto new_pot = projector.project(*pot, elt.second);
1134 pot_list.insert(new_pot);
1135 projected_pots.insert(new_pot);
1139 return projected_pots;
1144 template <
typename GUM_SCALAR >
1145 Set< const Potential< GUM_SCALAR >* >
1147 Set<
const Potential< GUM_SCALAR >* > pot_list,
1148 Set< const DiscreteVariable* >& del_vars,
1149 Set< const DiscreteVariable* >& kept_vars) {
1159 MultiDimCombineAndProjectDefault< GUM_SCALAR, Potential > combine_and_project(
1162 combine_and_project.combineAndProject(pot_list, del_vars);
1167 for (
auto iter = barren_projected_potentials.beginSafe();
1168 iter != barren_projected_potentials.endSafe();
1170 if (!new_pot_list.exists(*iter))
delete *iter;
1174 for (
auto iter_pot = new_pot_list.beginSafe();
1175 iter_pot != new_pot_list.endSafe();
1177 if ((*iter_pot)->variablesSequence().size() == 0) {
1184 new_pot_list.erase(iter_pot);
1188 return new_pot_list;
1193 template <
typename GUM_SCALAR >
1198 if (__clique_ss_potential.exists(from_id))
1199 pot_list.insert(__clique_ss_potential[from_id]);
1202 for (
const auto other_id : __JT->neighbours(from_id))
1203 if (other_id != to_id)
1204 pot_list += __separator_potentials[Arc(other_id, from_id)];
1207 const NodeSet& from_clique = __JT->clique(from_id);
1208 const NodeSet& separator = __JT->separator(from_id, to_id);
1209 Set< const DiscreteVariable* > del_vars(from_clique.size());
1210 Set< const DiscreteVariable* > kept_vars(separator.size());
1211 const auto& bn = this->
BN();
1213 for (
const auto node : from_clique) {
1214 if (!separator.contains(node)) {
1215 del_vars.insert(&(bn.variable(node)));
1217 kept_vars.insert(&(bn.variable(node)));
1227 for (
auto iter = new_pot_list.beginSafe(); iter != new_pot_list.endSafe();
1229 const auto pot = *iter;
1230 if (pot->variablesSequence().size() == 1) {
1231 bool is_all_ones =
true;
1232 for (Instantiation inst(*pot); !inst.end(); ++inst) {
1234 is_all_ones =
false;
1239 if (!pot_list.exists(pot))
delete pot;
1240 new_pot_list.erase(iter);
1248 const Arc arc(from_id, to_id);
1249 if (!new_pot_list.empty()) {
1250 if (new_pot_list.size() == 1) {
1253 auto pot = *(new_pot_list.begin());
1254 __separator_potentials[arc] = std::move(new_pot_list);
1255 if (!pot_list.exists(pot)) {
1256 if (!__created_potentials.exists(arc))
1258 __created_potentials[arc].insert(pot);
1262 MultiDimCombinationDefault< GUM_SCALAR, Potential > fast_combination(
1264 auto joint = fast_combination.combine(new_pot_list);
1265 __separator_potentials[arc].insert(joint);
1266 if (!__created_potentials.exists(arc))
1268 __created_potentials[arc].insert(joint);
1271 for (
const auto pot : new_pot_list) {
1272 if (!pot_list.exists(pot)) {
delete pot; }
1277 __messages_computed[arc] =
true;
1282 template <
typename GUM_SCALAR >
1285 for (
const auto node : this->
targets()) {
1289 if (__graph.exists(node)) {
1304 template <
typename GUM_SCALAR >
1305 Potential< GUM_SCALAR >*
1307 const auto& bn = this->
BN();
1312 return new Potential< GUM_SCALAR >(*(this->
evidence()[id]));
1324 if (__clique_ss_potential.exists(clique_of_id))
1325 pot_list.insert(__clique_ss_potential[clique_of_id]);
1328 for (
const auto other : __JT->neighbours(clique_of_id))
1329 pot_list += __separator_potentials[Arc(other, clique_of_id)];
1332 const NodeSet& nodes = __JT->clique(clique_of_id);
1333 Set< const DiscreteVariable* > kept_vars{&(bn.variable(
id))};
1334 Set< const DiscreteVariable* > del_vars(nodes.size());
1335 for (
const auto node : nodes) {
1336 if (node !=
id) del_vars.
insert(&(bn.variable(node)));
1342 Potential< GUM_SCALAR >* joint =
nullptr;
1344 if (new_pot_list.size() == 1) {
1345 joint =
const_cast< Potential< GUM_SCALAR >*
>(*(new_pot_list.begin()));
1348 if (pot_list.exists(joint)) {
1349 joint =
new Potential< GUM_SCALAR >(*joint);
1353 new_pot_list.clear();
1356 MultiDimCombinationDefault< GUM_SCALAR, Potential > fast_combination(
1358 joint = fast_combination.combine(new_pot_list);
1362 for (
auto pot : new_pot_list)
1363 if (!pot_list.exists(pot))
delete pot;
1368 bool nonzero_found =
false;
1369 for (Instantiation inst(*joint); !inst.end(); ++inst) {
1370 if ((*joint)[inst]) {
1371 nonzero_found =
true;
1375 if (!nonzero_found) {
1379 "some evidence entered into the Bayes " 1380 "net are incompatible (their joint proba = 0)");
1387 template <
typename GUM_SCALAR >
1388 const Potential< GUM_SCALAR >&
1391 if (__target_posteriors.exists(
id)) {
return *(__target_posteriors[id]); }
1396 __target_posteriors.insert(
id, joint);
1403 template <
typename GUM_SCALAR >
1404 Potential< GUM_SCALAR >*
1411 if (targets.contains(node)) {
1412 targets.
erase(node);
1413 hard_ev_nodes.insert(node);
1420 if (targets.empty()) {
1422 for (
const auto node :
set) {
1425 if (pot_list.size() == 1) {
1426 auto pot =
new Potential< GUM_SCALAR >(**(pot_list.begin()));
1429 MultiDimCombinationDefault< GUM_SCALAR, Potential > fast_combination(
1431 return fast_combination.combine(pot_list);
1442 }
catch (NotFound&) {
1448 for (
const auto node : targets) {
1449 if (!__graph.exists(node)) {
1450 GUM_ERROR(UndefinedElement, node <<
" is not a target node");
1456 const std::vector< NodeId >& JT_elim_order =
1458 NodeProperty< int > elim_order(
Size(JT_elim_order.size()));
1459 for (std::size_t i = std::size_t(0), size = JT_elim_order.size(); i < size;
1461 elim_order.insert(JT_elim_order[i], (
int)i);
1462 NodeId first_eliminated_node = *(targets.begin());
1463 int elim_number = elim_order[first_eliminated_node];
1464 for (
const auto node : targets) {
1465 if (elim_order[node] < elim_number) {
1466 elim_number = elim_order[node];
1467 first_eliminated_node = node;
1474 const NodeSet& clique_nodes = __JT->clique(clique_of_set);
1475 for (
const auto node : targets) {
1476 if (!clique_nodes.contains(node)) {
1477 GUM_ERROR(UndefinedElement,
set <<
" is not a joint target");
1492 if (__clique_ss_potential.exists(clique_of_set))
1493 pot_list.insert(__clique_ss_potential[clique_of_set]);
1496 for (
const auto other : __JT->neighbours(clique_of_set))
1497 pot_list += __separator_potentials[Arc(other, clique_of_set)];
1500 const NodeSet& nodes = __JT->clique(clique_of_set);
1501 Set< const DiscreteVariable* > del_vars(nodes.size());
1502 Set< const DiscreteVariable* > kept_vars(targets.size());
1503 const auto& bn = this->
BN();
1504 for (
const auto node : nodes) {
1505 if (!targets.contains(node)) {
1506 del_vars.insert(&(bn.variable(node)));
1508 kept_vars.insert(&(bn.variable(node)));
1515 Potential< GUM_SCALAR >* joint =
nullptr;
1517 if ((new_pot_list.size() == 1) && hard_ev_nodes.empty()) {
1518 joint =
const_cast< Potential< GUM_SCALAR >*
>(*(new_pot_list.begin()));
1521 if (pot_list.exists(joint)) {
1522 joint =
new Potential< GUM_SCALAR >(*joint);
1526 new_pot_list.clear();
1532 for (
const auto node : hard_ev_nodes) {
1533 new_new_pot_list.insert(
evidence[node]);
1535 MultiDimCombinationDefault< GUM_SCALAR, Potential > fast_combination(
1537 joint = fast_combination.combine(new_new_pot_list);
1541 for (
auto pot : new_pot_list)
1542 if (!pot_list.exists(pot))
delete pot;
1546 bool nonzero_found =
false;
1547 for (Instantiation inst(*joint); !inst.end(); ++inst) {
1548 if ((*joint)[inst]) {
1549 nonzero_found =
true;
1553 if (!nonzero_found) {
1557 "some evidence entered into the Bayes " 1558 "net are incompatible (their joint proba = 0)");
1566 template <
typename GUM_SCALAR >
1567 const Potential< GUM_SCALAR >&
1570 if (__joint_target_posteriors.exists(
set)) {
1571 return *(__joint_target_posteriors[
set]);
1577 __joint_target_posteriors.insert(
set, joint);
1584 template <
typename GUM_SCALAR >
1585 const Potential< GUM_SCALAR >&
1589 if (__joint_target_posteriors.exists(wanted_target))
1590 return *(__joint_target_posteriors[wanted_target]);
1596 if (!__joint_target_posteriors.exists(declared_target)) {
1601 const auto& bn = this->
BN();
1602 Set< const DiscreteVariable* > del_vars;
1603 for (
const auto node : declared_target)
1604 if (!wanted_target.contains(node)) del_vars.insert(&(bn.variable(node)));
1605 auto pot =
new Potential< GUM_SCALAR >(
1606 __joint_target_posteriors[declared_target]->margSumOut(del_vars));
1609 __joint_target_posteriors.insert(wanted_target, pot);
1615 template <
typename GUM_SCALAR >
1626 GUM_SCALAR prob_ev = 1;
1627 for (
const auto root :
__roots) {
1629 NodeId node = *(__JT->clique(root).begin());
1632 for (Instantiation iter(*tmp); !iter.end(); ++iter)
1633 sum += tmp->get(iter);
1638 for (
const auto& projected_cpt : __constants)
1639 prob_ev *= projected_cpt.second;
1647 #endif // DOXYGEN_SHOULD_SKIP_THIS NodeProperty< const Potential< GUM_SCALAR > *> __clique_ss_potential
the potentials stored into the cliques by Shafer-Shenoy
void _onAllJointTargetsErased() final
fired before a all the joint targets are removed
void _onEvidenceChanged(const NodeId id, bool hasChangedSoftHard) final
fired after an evidence is changed, in particular when its status (soft/hard) changes ...
FindBarrenNodesType __barren_nodes_type
the type of barren nodes computation we wish
void __setProjectionFunction(Potential< GUM_SCALAR > *(*proj)(const Potential< GUM_SCALAR > &, const Set< const DiscreteVariable * > &))
sets the operator for performing the projections
virtual void clear()
removes all the nodes and edges from the graph
virtual void addNodeWithId(const NodeId id)
try to insert a node with the given id
const NodeProperty< const Potential< GUM_SCALAR > *> & evidence() const
returns the set of evidence
NodeProperty< GUM_SCALAR > __constants
the constants resulting from the projections of CPTs defined over only hard evidence nodes remove th...
JunctionTree * __junctionTree
the junction tree to answer the last inference query
void __collectMessage(NodeId id, NodeId from)
actually perform the collect phase
void _onEvidenceErased(const NodeId id, bool isHardEvidence) final
fired before an evidence is removed
__PotentialSet __marginalizeOut(__PotentialSet pot_list, Set< const DiscreteVariable * > &del_vars, Set< const DiscreteVariable * > &kept_vars)
removes variables del_vars from a list of potentials and returns the resulting list ...
void _onMarginalTargetErased(const NodeId id) final
fired before a single target is removed
void _onMarginalTargetAdded(const NodeId id) final
fired after a new single target is inserted
Set< NodeId > NodeSet
Some typdefs and define for shortcuts ...
ArcProperty< bool > __messages_computed
indicates whether a message (from one clique to another) has been computed
Set< const Potential< GUM_SCALAR > *> __PotentialSet
void __setCombinationFunction(Potential< GUM_SCALAR > *(*comb)(const Potential< GUM_SCALAR > &, const Potential< GUM_SCALAR > &))
sets the operator for performing the combinations
d-separation analysis (as described in Koller & Friedman 2009)
const GUM_SCALAR __1_minus_epsilon
for comparisons with 1 - epsilon
Implementation of Shafer-Shenoy's algorithm for inference in Bayesian Networks.
Potential< GUM_SCALAR > * _unnormalizedJointPosterior(NodeId id) final
returns a fresh potential equal to P(argument,evidence)
void _onAllMarginalTargetsAdded() final
fired after all the nodes of the BN are added as single targets
NodeSet __roots
a clique node used as a root in each connected component of __JT
The BayesBall algorithm (as described by Schachter).
bool exists(const Key &key) const
Checks whether there exists an element with a given key in the hashtable.
void erase(const Key &k)
Erases an element from the set.
bool exists(const NodeId id) const
alias for existsNode
GUM_SCALAR evidenceProbability()
returns the probability of evidence
void _makeInference() final
called when the inference has to be performed effectively
gum is the global namespace for all aGrUM entities
HashTable< NodeSet, const Potential< GUM_SCALAR > *> __joint_target_posteriors
the set of set target posteriors computed during the last inference
void _updateOutdatedBNPotentials() final
prepares inference when the latter is in OutdatedBNPotentials state
UndiGraph __graph
the undigraph extracted from the BN and used to construct the join tree
virtual void eraseNode(const NodeId id)
remove a node and its adjacent edges from the graph
virtual void makeInference() final
perform the heavy computations needed to compute the targets' posteriors
ShaferShenoyInference(const IBayesNet< GUM_SCALAR > *BN, FindBarrenNodesType barren_type=FindBarrenNodesType::FIND_BARREN_NODES, bool use_binary_join_tree=true)
default constructor
NodeProperty< EvidenceChangeType > __evidence_changes
indicates which nodes of the BN have evidence that changed since the last inference ...
FindBarrenNodesType
type of algorithm to determine barren nodes
NodeProperty< const Potential< GUM_SCALAR > *> __node_to_soft_evidence
the soft evidence stored in the cliques per their assigned node in the BN
const NodeSet & softEvidenceNodes() const
returns the set of nodes with soft evidence
HashTable< NodeId, NodeId > __node_to_clique
for each node of __graph (~ in the Bayes net), associate an ID in the JT
virtual const NodeProperty< Size > & domainSizes() const final
get the domain sizes of the random variables of the BN
NodeProperty< const Potential< GUM_SCALAR > *> __target_posteriors
the set of single posteriors computed during the last inference
NodeProperty< const Potential< GUM_SCALAR > *> __hard_ev_projected_CPTs
the CPTs that were projected due to hard evidence nodes
virtual NodeId createdJunctionTreeClique(const NodeId id)=0
returns the Id of the clique created by the elimination of a given node during the triangulation proc...
void _onJointTargetAdded(const NodeSet &set) final
fired after a new joint target is inserted
bool __isNewJTNeeded() const
check whether a new join tree is really needed for the next inference
void _updateOutdatedBNStructure() final
prepares inference when the latter is in OutdatedBNStructure state
const NodeProperty< Idx > & hardEvidence() const
indicate for each node with hard evidence which value it took
virtual bool isDone() const noexcept final
returns whether the inference object is in a done state
void setTriangulation(const Triangulation &new_triangulation)
use a new triangulation algorithm
CliqueGraph JoinTree
a join tree is a clique graph satisfying the running intersection property (but some cliques may be i...
const Potential< GUM_SCALAR > & _jointPosterior(const NodeSet &set) final
returns the posterior of a declared target set
virtual const Set< NodeSet > & jointTargets() const noexcept final
returns the list of joint targets
void _onAllTargetsErased() final
fired before a all single and joint_targets are removed
virtual void _onBayesNetChanged(const IBayesNet< GUM_SCALAR > *bn) final
fired after a new Bayes net has been assigned to the engine
bool __is_new_jt_needed
indicates whether a new join tree is needed for the next inference
static INLINE Potential< GUM_SCALAR > * SSNewprojPotential(const Potential< GUM_SCALAR > &t1, const Set< const DiscreteVariable * > &del_vars)
Header files of gum::Instantiation.
void setFindBarrenNodesType(FindBarrenNodesType type)
sets how we determine barren nodes
void _onJointTargetErased(const NodeSet &set) final
fired before a joint target is removed
HashTable< NodeSet, NodeId > __joint_target_to_clique
for each set target, assign a clique in the JT that contains it
void _setOutdatedBNStructureState()
put the inference into an outdated BN structure state
~ShaferShenoyInference() final
destructor
const NodeSet & clique(const NodeId idClique) const
returns the set of nodes included into a given clique
ArcProperty< __PotentialSet > __created_potentials
the set of potentials created for the last inference messages
const NodeSet & hardEvidenceNodes() const
returns the set of nodes with hard evidence
Triangulation * __triangulation
the triangulation class creating the junction tree used for inference
void __createNewJT()
create a new junction tree as well as its related data structures
CliqueGraph JunctionTree
a junction tree is a clique graph satisfying the running intersection property and such that no cliqu...
Detect barren nodes for inference in Bayesian networks.
void clear()
Removes all the elements in the hash table.
void __diffuseMessageInvalidations(NodeId from, NodeId to, NodeSet &cliques_invalidated)
invalidate all the messages sent from a given clique
An algorithm for converting a join tree into a binary join tree.
__PotentialSet __removeBarrenVariables(__PotentialSet &pot_list, Set< const DiscreteVariable * > &del_vars)
JoinTree * __JT
the join (or junction) tree used to answer the last inference query
void __computeJoinTreeRoots()
compute a root for each connected component of __JT
virtual bool isInferenceReady() const noexcept final
returns whether the inference object is in a ready state
void _onEvidenceAdded(const NodeId id, bool isHardEvidence) final
fired after a new evidence is inserted
virtual Triangulation * newFactory() const =0
returns a fresh triangulation of the same type as the current object but with an empty graph ...
virtual const NodeSet & targets() const noexcept final
returns the list of marginal targets
NodeSet __hard_ev_nodes
the hard evidence nodes which were projected in CPTs
bool __use_binary_join_tree
indicates whether we should transform junction trees into binary join trees
void clear()
Removes all the elements, if any, from the set.
std::size_t Size
In aGrUM, hashed values are unsigned long int.
NodeProperty< __PotentialSet > __clique_potentials
the list of all potentials stored in the cliques
Potential< GUM_SCALAR > *(* __combination_op)(const Potential< GUM_SCALAR > &, const Potential< GUM_SCALAR > &)
the operator for performing the combinations
const Potential< GUM_SCALAR > & _posterior(NodeId id) final
returns the posterior of a given variable
virtual const CliqueGraph & junctionTree()=0
returns a compatible junction tree
value_type & insert(const Key &key, const Val &val)
Adds a new element (actually a copy of this element) into the hash table.
const JoinTree * joinTree()
returns the current join tree used
virtual const std::vector< NodeId > & eliminationOrder()=0
returns an elimination ordering compatible with the triangulated graph
void _onAllEvidenceErased(bool contains_hard_evidence) final
fired before all the evidence are erased
void __invalidateAllMessages()
invalidate all messages, posteriors and created potentials
virtual const IBayesNet< GUM_SCALAR > & BN() const final
Returns a constant reference over the IBayesNet referenced by this class.
void _onAllMarginalTargetsErased() final
fired before a all the single targets are removed
Potential< GUM_SCALAR > *(* __projection_op)(const Potential< GUM_SCALAR > &, const Set< const DiscreteVariable * > &)
the operator for performing the projections
Size NodeId
Type for node ids.
void insert(const Key &k)
Inserts a new element into the set.
void __produceMessage(NodeId from_id, NodeId to_id)
creates the message sent by clique from_id to clique to_id
#define GUM_ERROR(type, msg)
virtual void setGraph(const UndiGraph *graph, const NodeProperty< Size > *domsizes)=0
initialize the triangulation data structures for a new graph
void _setOutdatedBNPotentialsState()
puts the inference into an OutdatedBNPotentials state if it is not already in an OutdatedBNStructure ...
ArcProperty< __PotentialSet > __separator_potentials
the list of all potentials stored in the separators after inferences
const JunctionTree * junctionTree()
returns the current junction tree