31 #ifndef DOXYGEN_SHOULD_SKIP_THIS 45 template <
typename GUM_SCALAR >
47 const IBayesNet< GUM_SCALAR >* BN,
50 bool use_binary_join_tree) :
51 JointTargetedInference< GUM_SCALAR >(BN),
52 EvidenceInference< GUM_SCALAR >(BN),
53 __use_binary_join_tree(use_binary_join_tree) {
67 template <
typename GUM_SCALAR >
71 for (
const auto pot : pots.second)
96 template <
typename GUM_SCALAR >
98 const Triangulation& new_triangulation) {
107 template <
typename GUM_SCALAR >
115 template <
typename GUM_SCALAR >
124 template <
typename GUM_SCALAR >
151 "setRelevantPotentialsFinderType for type " 152 << (
unsigned int)type <<
" is not implemented yet");
165 template <
typename GUM_SCALAR >
167 Potential< GUM_SCALAR >* (*proj)(
const Potential< GUM_SCALAR >&,
168 const Set< const DiscreteVariable* >&)) {
174 template <
typename GUM_SCALAR >
176 Potential< GUM_SCALAR >* (*comb)(
const Potential< GUM_SCALAR >&,
177 const Potential< GUM_SCALAR >&)) {
183 template <
typename GUM_SCALAR >
187 potset.second.clear();
189 mess_computed.second =
false;
192 for (
const auto& potset : __created_potentials)
193 for (
const auto pot : potset.second)
197 for (
const auto& pot : __target_posteriors)
199 for (
const auto& pot : __joint_target_posteriors)
209 template <
typename GUM_SCALAR >
222 "setFindBarrenNodesType for type " 223 << (
unsigned int)type <<
" is not implemented yet");
235 template <
typename GUM_SCALAR >
238 bool isHardEvidence) {
247 }
catch (DuplicateElement&) {
259 template <
typename GUM_SCALAR >
262 bool isHardEvidence) {
270 }
catch (DuplicateElement&) {
286 template <
typename GUM_SCALAR >
295 }
catch (DuplicateElement&) {
312 template <
typename GUM_SCALAR >
315 bool hasChangedSoftHard) {
316 if (hasChangedSoftHard)
321 }
catch (DuplicateElement&) {
331 template <
typename GUM_SCALAR >
333 const IBayesNet< GUM_SCALAR >* bn) {}
337 template <
typename GUM_SCALAR >
343 template <
typename GUM_SCALAR >
349 template <
typename GUM_SCALAR >
355 template <
typename GUM_SCALAR >
361 template <
typename GUM_SCALAR >
366 template <
typename GUM_SCALAR >
371 template <
typename GUM_SCALAR >
376 template <
typename GUM_SCALAR >
381 template <
typename GUM_SCALAR >
394 for (
const auto node : this->
targets()) {
395 if (!
__graph.
exists(node) && !hard_ev_nodes.exists(node))
return true;
400 bool containing_clique_found =
false;
401 for (
const auto node : joint_target) {
405 for (
const auto xnode : joint_target) {
406 if (!clique.contains(xnode) && !hard_ev_nodes.exists(xnode)) {
411 }
catch (NotFound&) { found =
false; }
414 containing_clique_found =
true;
419 if (!containing_clique_found)
return true;
425 if ((change.second == EvidenceChangeType::EVIDENCE_ADDED)
436 template <
typename GUM_SCALAR >
452 const auto& bn = this->
BN();
454 for (
const auto node : bn.dag())
467 target_nodes += nodeset;
472 if (target_nodes.size() != bn.size()) {
473 BarrenNodesFinder finder(&(bn.dag()));
474 finder.setTargets(&target_nodes);
477 for (
const auto& pair : this->
evidence()) {
478 evidence_nodes.
insert(pair.first);
480 finder.setEvidence(&evidence_nodes);
482 NodeSet barren_nodes = finder.barrenNodes();
485 for (
const auto node : barren_nodes) {
492 for (
const auto node :
__graph) {
493 const NodeSet& parents = bn.parents(node);
494 for (
auto iter1 = parents.cbegin(); iter1 != parents.cend(); ++iter1) {
495 __graph.addEdge(*iter1, node);
497 for (++iter2; iter2 != parents.cend(); ++iter2) {
498 __graph.addEdge(*iter1, *iter2);
507 for (
auto iter1 = nodeset.cbegin(); iter1 != nodeset.cend(); ++iter1) {
509 for (++iter2; iter2 != nodeset.cend(); ++iter2) {
510 __graph.addEdge(*iter1, *iter2);
518 __graph.eraseNode(node);
531 BinaryJoinTreeConverterDefault bjt_converter;
533 __JT =
new CliqueGraph(
534 bjt_converter.convert(triang_jt, this->domainSizes(), emptyset));
536 __JT =
new CliqueGraph(triang_jt);
544 const std::vector< NodeId >& JT_elim_order =
546 NodeProperty< int > elim_order(
Size(JT_elim_order.size()));
547 for (std::size_t i = std::size_t(0), size = JT_elim_order.size(); i < size;
549 elim_order.insert(JT_elim_order[i], (
int)i);
550 const DAG& dag = bn.dag();
551 for (
const auto node : __graph) {
553 NodeId first_eliminated_node = node;
554 int elim_number = elim_order[first_eliminated_node];
556 for (
const auto parent : dag.parents(node)) {
557 if (__graph.existsNode(parent) && (elim_order[parent] < elim_number)) {
558 elim_number = elim_order[parent];
559 first_eliminated_node = parent;
574 for (
const auto node : __hard_ev_nodes) {
576 NodeSet pars(dag.parents(node).size());
577 for (
const auto par : dag.parents(node))
578 if (__graph.exists(par)) pars.
insert(par);
581 NodeId first_eliminated_node = *(pars.begin());
582 int elim_number = elim_order[first_eliminated_node];
584 for (
const auto parent : pars) {
585 if (elim_order[parent] < elim_number) {
586 elim_number = elim_order[parent];
587 first_eliminated_node = parent;
607 for (
const auto node : __hard_ev_nodes)
608 if (nodeset.contains(node)) nodeset.
erase(node);
610 if (!nodeset.empty()) {
613 NodeId first_eliminated_node = *(nodeset.begin());
614 int elim_number = elim_order[first_eliminated_node];
615 for (
const auto node : nodeset) {
616 if (elim_order[node] < elim_number) {
617 elim_number = elim_order[node];
618 first_eliminated_node = node;
636 for (
const auto node : *
__JT) {
641 for (
const auto& potlist : __created_potentials)
642 for (
const auto pot : potlist.second)
644 __created_potentials.clear();
648 for (
const auto pot_pair : __hard_ev_projected_CPTs)
649 delete pot_pair.second;
650 __hard_ev_projected_CPTs.clear();
658 __separator_potentials.clear();
659 __messages_computed.clear();
660 for (
const auto& edge : __JT->edges()) {
661 const Arc arc1(edge.first(), edge.second());
662 __separator_potentials.insert(arc1, empty_set);
663 __messages_computed.insert(arc1,
false);
664 const Arc arc2(Arc(edge.second(), edge.first()));
665 __separator_potentials.insert(arc2, empty_set);
666 __messages_computed.insert(arc2,
false);
670 for (
const auto& pot : __target_posteriors)
672 __target_posteriors.clear();
673 for (
const auto& pot : __joint_target_posteriors)
675 __joint_target_posteriors.clear();
684 for (
const auto node : dag) {
685 if (__graph.exists(node) || __hard_ev_nodes.contains(node)) {
686 const Potential< GUM_SCALAR >& cpt = bn.cpt(node);
690 const auto& variables = cpt.variablesSequence();
691 for (
const auto var : variables) {
692 NodeId xnode = bn.nodeId(*var);
693 if (__hard_ev_nodes.contains(xnode)) hard_nodes.insert(xnode);
699 if (hard_nodes.empty()) {
705 if (hard_nodes.size() == variables.size()) {
707 const auto& vars = cpt.variablesSequence();
708 for (
const auto var : vars)
710 for (
Size i = 0; i < hard_nodes.size(); ++i) {
711 inst.chgVal(variables[i], hard_evidence[bn.nodeId(*(variables[i]))]);
716 Set< const DiscreteVariable* > hard_variables;
718 for (
const auto xnode : hard_nodes) {
719 marg_cpt_set.insert(
evidence[xnode]);
720 hard_variables.insert(&(bn.variable(xnode)));
724 MultiDimCombineAndProjectDefault< GUM_SCALAR, Potential >
727 combine_and_project.combineAndProject(marg_cpt_set, hard_variables);
730 if (new_cpt_list.size() != 1) {
732 for (
const auto pot : new_cpt_list) {
733 if (!marg_cpt_set.contains(pot))
delete pot;
736 "the projection of a potential containing " 737 <<
"hard evidence is empty!");
739 const Potential< GUM_SCALAR >* projected_cpt = *(new_cpt_list.begin());
741 __hard_ev_projected_CPTs.insert(node, projected_cpt);
754 __evidence_changes.clear();
760 template <
typename GUM_SCALAR >
775 template <
typename GUM_SCALAR >
779 invalidated_cliques.insert(to_id);
782 const Arc arc(from_id, to_id);
783 bool& message_computed = __messages_computed[arc];
784 if (message_computed) {
785 message_computed =
false;
786 __separator_potentials[arc].clear();
787 if (__created_potentials.exists(arc)) {
788 auto& arc_created_potentials = __created_potentials[arc];
789 for (
const auto pot : arc_created_potentials)
791 arc_created_potentials.clear();
795 for (
const auto node_id : __JT->neighbours(to_id)) {
796 if (node_id != from_id)
805 template <
typename GUM_SCALAR >
814 NodeSet hard_nodes_changed(__hard_ev_nodes.size());
815 for (
const auto node : __hard_ev_nodes)
816 if (__evidence_changes.exists(node)) hard_nodes_changed.
insert(node);
818 NodeSet nodes_with_projected_CPTs_changed;
819 const auto& bn = this->
BN();
820 for (
auto pot_iter = __hard_ev_projected_CPTs.beginSafe();
821 pot_iter != __hard_ev_projected_CPTs.endSafe();
823 for (
const auto var : bn.cpt(pot_iter.key()).variablesSequence()) {
824 if (hard_nodes_changed.contains(bn.nodeId(*var))) {
825 nodes_with_projected_CPTs_changed.insert(pot_iter.key());
826 delete pot_iter.val();
829 __hard_ev_projected_CPTs.erase(pot_iter);
843 NodeSet invalidated_cliques(__JT->size());
844 for (
const auto& pair : __evidence_changes) {
847 invalidated_cliques.
insert(clique);
848 for (
const auto neighbor : __JT->neighbours(clique)) {
856 for (
const auto node : nodes_with_projected_CPTs_changed) {
858 invalidated_cliques.
insert(clique);
859 for (
const auto neighbor : __JT->neighbours(clique)) {
869 for (
auto iter = __target_posteriors.beginSafe();
870 iter != __target_posteriors.endSafe();
872 if (__graph.exists(iter.key())
875 __target_posteriors.erase(iter);
880 for (
auto iter = __target_posteriors.beginSafe();
881 iter != __target_posteriors.endSafe();
883 if (hard_nodes_changed.contains(iter.key())) {
885 __target_posteriors.erase(iter);
890 for (
auto iter = __joint_target_posteriors.beginSafe();
891 iter != __joint_target_posteriors.endSafe();
895 __joint_target_posteriors.erase(iter);
905 __node_to_soft_evidence.clear();
909 __node_to_soft_evidence.insert(node,
evidence[node]);
920 for (
const auto node : nodes_with_projected_CPTs_changed) {
922 const Potential< GUM_SCALAR >& cpt = bn.cpt(node);
923 const auto& variables = cpt.variablesSequence();
924 Set< const DiscreteVariable* > hard_variables;
926 for (
const auto var : variables) {
927 NodeId xnode = bn.nodeId(*var);
928 if (__hard_ev_nodes.exists(xnode)) {
929 marg_cpt_set.insert(
evidence[xnode]);
930 hard_variables.insert(var);
935 MultiDimCombineAndProjectDefault< GUM_SCALAR, Potential >
938 combine_and_project.combineAndProject(marg_cpt_set, hard_variables);
941 if (new_cpt_list.size() != 1) {
943 for (
const auto pot : new_cpt_list) {
944 if (!marg_cpt_set.contains(pot))
delete pot;
947 "the projection of a potential containing " 948 <<
"hard evidence is empty!");
950 const Potential< GUM_SCALAR >* projected_cpt = *(new_cpt_list.begin());
952 __hard_ev_projected_CPTs.insert(node, projected_cpt);
958 const Potential< GUM_SCALAR >& cpt = bn.cpt(node_cst.first);
959 const auto& variables = cpt.variablesSequence();
961 for (
const auto var : variables)
963 for (
const auto var : variables) {
964 inst.chgVal(var, hard_evidence[bn.nodeId(*var)]);
966 node_cst.second = cpt.get(inst);
970 __evidence_changes.clear();
975 template <
typename GUM_SCALAR >
979 for (
const auto node : this->
targets()) {
982 }
catch (Exception&) {}
987 }
catch (Exception&) {}
991 std::vector< std::pair< NodeId, Size > > possible_roots(clique_targets.size());
992 const auto& bn = this->
BN();
994 for (
const auto clique_id : clique_targets) {
995 const auto& clique = __JT->clique(clique_id);
997 for (
const auto node : clique) {
998 dom_size *= bn.variable(node).domainSize();
1000 possible_roots[i] = std::pair< NodeId, Size >(clique_id, dom_size);
1005 std::sort(possible_roots.begin(),
1006 possible_roots.end(),
1007 [](
const std::pair< NodeId, Size >& a,
1008 const std::pair< NodeId, Size >& b) ->
bool {
1009 return a.second < b.second;
1013 NodeProperty< bool > marked = __JT->nodesProperty(
false);
1014 std::function< void(NodeId, NodeId) > diffuse_marks =
1015 [&marked, &diffuse_marks,
this](
NodeId node,
NodeId from) {
1016 if (!marked[node]) {
1017 marked[node] =
true;
1018 for (
const auto neigh : __JT->neighbours(node))
1019 if ((neigh != from) && !marked[neigh]) diffuse_marks(neigh, node);
1023 for (
const auto xclique : possible_roots) {
1024 NodeId clique = xclique.first;
1025 if (!marked[clique]) {
1027 diffuse_marks(clique, clique);
1034 template <
typename GUM_SCALAR >
1037 for (
const auto other : __JT->neighbours(
id)) {
1038 if ((other != from) && !__messages_computed[Arc(other,
id)])
1042 if ((
id != from) && !__messages_computed[Arc(
id, from)]) {
1049 template <
typename GUM_SCALAR >
1051 Set<
const Potential< GUM_SCALAR >* >& pot_list,
1052 Set< const DiscreteVariable* >& kept_vars) {}
1056 template <
typename GUM_SCALAR >
1058 Set<
const Potential< GUM_SCALAR >* >& pot_list,
1059 Set< const DiscreteVariable* >& kept_vars) {
1062 const auto& bn = this->
BN();
1063 for (
const auto var : kept_vars) {
1064 kept_ids.insert(bn.nodeId(*var));
1074 for (
auto iter = pot_list.beginSafe(); iter != pot_list.endSafe(); ++iter) {
1075 const Sequence< const DiscreteVariable* >& vars =
1076 (**iter).variablesSequence();
1078 for (
const auto var : vars) {
1079 if (requisite_nodes.exists(bn.nodeId(*var))) {
1085 if (!found) { pot_list.erase(iter); }
1091 template <
typename GUM_SCALAR >
1093 Set<
const Potential< GUM_SCALAR >* >& pot_list,
1094 Set< const DiscreteVariable* >& kept_vars) {
1097 const auto& bn = this->
BN();
1098 for (
const auto var : kept_vars) {
1099 kept_ids.insert(bn.nodeId(*var));
1112 template <
typename GUM_SCALAR >
1114 Set<
const Potential< GUM_SCALAR >* >& pot_list,
1115 Set< const DiscreteVariable* >& kept_vars) {
1118 const auto& bn = this->
BN();
1119 for (
const auto var : kept_vars) {
1120 kept_ids.insert(bn.nodeId(*var));
1125 dsep.relevantPotentials(bn,
1134 template <
typename GUM_SCALAR >
1136 Set<
const Potential< GUM_SCALAR >* >& pot_list,
1137 Set< const DiscreteVariable* >& kept_vars) {
1155 default:
GUM_ERROR(FatalError,
"not implemented yet");
1161 template <
typename GUM_SCALAR >
1162 Set< const Potential< GUM_SCALAR >* >
1164 __PotentialSet& pot_list, Set< const DiscreteVariable* >& del_vars) {
1167 Set< const DiscreteVariable* > the_del_vars = del_vars;
1168 for (
auto iter = the_del_vars.beginSafe(); iter != the_del_vars.endSafe();
1170 NodeId id = this->
BN().nodeId(**iter);
1173 the_del_vars.erase(iter);
1178 HashTable< const DiscreteVariable*, __PotentialSet > var2pots;
1180 for (
const auto pot : pot_list) {
1181 const Sequence< const DiscreteVariable* >& vars = pot->variablesSequence();
1182 for (
const auto var : vars) {
1183 if (the_del_vars.exists(var)) {
1184 if (!var2pots.exists(var)) { var2pots.insert(var, empty_pot_set); }
1185 var2pots[var].insert(pot);
1192 HashTable< const Potential< GUM_SCALAR >*, Set< const DiscreteVariable* > >
1194 Set< const DiscreteVariable* > empty_var_set;
1195 for (
const auto elt : var2pots) {
1196 if (elt.second.size() == 1) {
1197 const Potential< GUM_SCALAR >* pot = *(elt.second.begin());
1198 if (!pot2barren_var.exists(pot)) {
1199 pot2barren_var.insert(pot, empty_var_set);
1201 pot2barren_var[pot].insert(elt.first);
1210 for (
const auto elt : pot2barren_var) {
1213 const Potential< GUM_SCALAR >* pot = elt.first;
1214 pot_list.erase(pot);
1218 if (pot->variablesSequence().size() != elt.second.size()) {
1219 auto new_pot = projector.project(*pot, elt.second);
1220 pot_list.insert(new_pot);
1221 projected_pots.insert(new_pot);
1225 return projected_pots;
1230 template <
typename GUM_SCALAR >
1231 Set< const Potential< GUM_SCALAR >* >
1233 Set<
const Potential< GUM_SCALAR >* > pot_list,
1234 Set< const DiscreteVariable* >& del_vars,
1235 Set< const DiscreteVariable* >& kept_vars) {
1248 MultiDimCombineAndProjectDefault< GUM_SCALAR, Potential > combine_and_project(
1251 combine_and_project.combineAndProject(pot_list, del_vars);
1256 for (
auto iter = barren_projected_potentials.beginSafe();
1257 iter != barren_projected_potentials.endSafe();
1259 if (!new_pot_list.exists(*iter))
delete *iter;
1263 for (
auto iter_pot = new_pot_list.beginSafe();
1264 iter_pot != new_pot_list.endSafe();
1266 if ((*iter_pot)->variablesSequence().size() == 0) {
1273 new_pot_list.erase(iter_pot);
1277 return new_pot_list;
1282 template <
typename GUM_SCALAR >
1289 for (
const auto other_id : __JT->neighbours(from_id))
1290 if (other_id != to_id)
1291 pot_list += __separator_potentials[Arc(other_id, from_id)];
1294 const NodeSet& from_clique = __JT->clique(from_id);
1295 const NodeSet& separator = __JT->separator(from_id, to_id);
1296 Set< const DiscreteVariable* > del_vars(from_clique.size());
1297 Set< const DiscreteVariable* > kept_vars(separator.size());
1298 const auto& bn = this->
BN();
1300 for (
const auto node : from_clique) {
1301 if (!separator.contains(node)) {
1302 del_vars.insert(&(bn.variable(node)));
1304 kept_vars.insert(&(bn.variable(node)));
1315 const Arc arc(from_id, to_id);
1316 for (
auto iter = new_pot_list.beginSafe(); iter != new_pot_list.endSafe();
1318 const auto pot = *iter;
1319 if (pot->variablesSequence().size() == 1) {
1320 bool is_all_ones =
true;
1321 for (Instantiation inst(*pot); !inst.end(); ++inst) {
1323 is_all_ones =
false;
1328 if (!pot_list.exists(pot))
delete pot;
1329 new_pot_list.erase(iter);
1334 if (!pot_list.exists(pot)) {
1335 if (!__created_potentials.exists(arc))
1337 __created_potentials[arc].insert(pot);
1341 __separator_potentials[arc] = std::move(new_pot_list);
1342 __messages_computed[arc] =
true;
1347 template <
typename GUM_SCALAR >
1350 for (
const auto node : this->
targets()) {
1354 if (__graph.exists(node)) {
1369 template <
typename GUM_SCALAR >
1370 Potential< GUM_SCALAR >*
1372 const auto& bn = this->
BN();
1377 return new Potential< GUM_SCALAR >(*(this->
evidence()[id]));
1391 for (
const auto other : __JT->neighbours(clique_of_id))
1392 pot_list += __separator_potentials[Arc(other, clique_of_id)];
1395 const NodeSet& nodes = __JT->clique(clique_of_id);
1396 Set< const DiscreteVariable* > kept_vars{&(bn.variable(
id))};
1397 Set< const DiscreteVariable* > del_vars(nodes.size());
1398 for (
const auto node : nodes) {
1399 if (node !=
id) del_vars.
insert(&(bn.variable(node)));
1405 Potential< GUM_SCALAR >* joint =
nullptr;
1407 if (new_pot_list.size() == 1) {
1408 joint =
const_cast< Potential< GUM_SCALAR >*
>(*(new_pot_list.begin()));
1411 if (pot_list.exists(joint)) {
1412 joint =
new Potential< GUM_SCALAR >(*joint);
1416 new_pot_list.clear();
1419 MultiDimCombinationDefault< GUM_SCALAR, Potential > fast_combination(
1421 joint = fast_combination.combine(new_pot_list);
1425 for (
const auto pot : new_pot_list)
1426 if (!pot_list.exists(pot))
delete pot;
1431 bool nonzero_found =
false;
1432 for (Instantiation inst(*joint); !inst.end(); ++inst) {
1433 if (joint->get(inst)) {
1434 nonzero_found =
true;
1438 if (!nonzero_found) {
1442 "some evidence entered into the Bayes " 1443 "net are incompatible (their joint proba = 0)");
1450 template <
typename GUM_SCALAR >
1451 const Potential< GUM_SCALAR >&
1454 if (__target_posteriors.exists(
id)) {
return *(__target_posteriors[id]); }
1459 __target_posteriors.insert(
id, joint);
1466 template <
typename GUM_SCALAR >
1467 Potential< GUM_SCALAR >*
1474 if (targets.contains(node)) {
1475 targets.
erase(node);
1476 hard_ev_nodes.insert(node);
1483 if (targets.empty()) {
1485 for (
const auto node :
set) {
1488 if (pot_list.size() == 1) {
1489 auto pot =
new Potential< GUM_SCALAR >(**(pot_list.begin()));
1492 MultiDimCombinationDefault< GUM_SCALAR, Potential > fast_combination(
1494 return fast_combination.combine(pot_list);
1505 }
catch (NotFound&) {
1511 for (
const auto node : targets) {
1512 if (!__graph.exists(node)) {
1513 GUM_ERROR(UndefinedElement, node <<
" is not a target node");
1519 const std::vector< NodeId >& JT_elim_order =
1522 NodeProperty< int > elim_order(
Size(JT_elim_order.size()));
1523 for (std::size_t i = std::size_t(0), size = JT_elim_order.size(); i < size;
1525 elim_order.insert(JT_elim_order[i], (
int)i);
1526 NodeId first_eliminated_node = *(targets.begin());
1527 int elim_number = elim_order[first_eliminated_node];
1528 for (
const auto node : targets) {
1529 if (elim_order[node] < elim_number) {
1530 elim_number = elim_order[node];
1531 first_eliminated_node = node;
1540 const NodeSet& clique_nodes = __JT->clique(clique_of_set);
1541 for (
const auto node : targets) {
1542 if (!clique_nodes.contains(node)) {
1543 GUM_ERROR(UndefinedElement,
set <<
" is not a joint target");
1560 for (
const auto other : __JT->neighbours(clique_of_set))
1561 pot_list += __separator_potentials[Arc(other, clique_of_set)];
1564 const NodeSet& nodes = __JT->clique(clique_of_set);
1565 Set< const DiscreteVariable* > del_vars(nodes.size());
1566 Set< const DiscreteVariable* > kept_vars(targets.size());
1567 const auto& bn = this->
BN();
1568 for (
const auto node : nodes) {
1569 if (!targets.contains(node)) {
1570 del_vars.insert(&(bn.variable(node)));
1572 kept_vars.insert(&(bn.variable(node)));
1579 Potential< GUM_SCALAR >* joint =
nullptr;
1581 if ((new_pot_list.size() == 1) && hard_ev_nodes.empty()) {
1582 joint =
const_cast< Potential< GUM_SCALAR >*
>(*(new_pot_list.begin()));
1586 if (pot_list.exists(joint)) {
1587 joint =
new Potential< GUM_SCALAR >(*joint);
1591 new_pot_list.clear();
1597 for (
const auto node : hard_ev_nodes) {
1598 new_new_pot_list.insert(
evidence[node]);
1600 MultiDimCombinationDefault< GUM_SCALAR, Potential > fast_combination(
1602 joint = fast_combination.combine(new_new_pot_list);
1606 for (
const auto pot : new_pot_list)
1607 if (!pot_list.exists(pot))
delete pot;
1611 bool nonzero_found =
false;
1612 for (Instantiation inst(*joint); !inst.end(); ++inst) {
1613 if ((*joint)[inst]) {
1614 nonzero_found =
true;
1618 if (!nonzero_found) {
1622 "some evidence entered into the Bayes " 1623 "net are incompatible (their joint proba = 0)");
1631 template <
typename GUM_SCALAR >
1632 const Potential< GUM_SCALAR >&
1635 if (__joint_target_posteriors.exists(
set)) {
1636 return *(__joint_target_posteriors[
set]);
1642 __joint_target_posteriors.insert(
set, joint);
1649 template <
typename GUM_SCALAR >
1653 if (__joint_target_posteriors.exists(wanted_target))
1654 return *(__joint_target_posteriors[wanted_target]);
1660 if (!__joint_target_posteriors.exists(declared_target)) {
1665 const auto& bn = this->
BN();
1666 Set< const DiscreteVariable* > del_vars;
1667 for (
const auto node : declared_target)
1668 if (!wanted_target.contains(node)) del_vars.insert(&(bn.variable(node)));
1669 Potential< GUM_SCALAR >* pot =
new Potential< GUM_SCALAR >(
1670 __joint_target_posteriors[declared_target]->margSumOut(del_vars));
1673 __joint_target_posteriors.insert(wanted_target, pot);
1679 template <
typename GUM_SCALAR >
1704 GUM_SCALAR prob_ev = 1;
1705 for (
const auto root :
__roots) {
1707 NodeId node = *(__JT->clique(root).begin());
1710 for (Instantiation iter(*tmp); !iter.end(); ++iter)
1711 sum += tmp->get(iter);
1716 for (
const auto& projected_cpt : __constants)
1717 prob_ev *= projected_cpt.second;
1727 #endif // DOXYGEN_SHOULD_SKIP_THIS ~LazyPropagation() final
destructor
NodeProperty< const Potential< GUM_SCALAR > *> __node_to_soft_evidence
the soft evidence stored in the cliques per their assigned node in the BN
HashTable< NodeSet, const Potential< GUM_SCALAR > *> __joint_target_posteriors
the set of set target posteriors computed during the last inference
ArcProperty< bool > __messages_computed
indicates whether a message (from one clique to another) has been computed
void setFindBarrenNodesType(FindBarrenNodesType type)
sets how we determine barren nodes
NodeProperty< const Potential< GUM_SCALAR > *> __hard_ev_projected_CPTs
the CPTs that were projected due to hard evidence nodes
Set< const Potential< GUM_SCALAR > *> __PotentialSet
void _updateOutdatedBNStructure() final
prepares inference when the latter is in OutdatedBNStructure state
virtual void clear()
removes all the nodes and edges from the graph
virtual void addNodeWithId(const NodeId id)
try to insert a node with the given id
const NodeProperty< const Potential< GUM_SCALAR > *> & evidence() const
returns the set of evidence
bool __isNewJTNeeded() const
check whether a new join tree is really needed for the next inference
void setRelevantPotentialsFinderType(RelevantPotentialsFinderType type)
sets how we determine the relevant potentials to combine
JunctionTree * __junctionTree
the junction tree to answer the last inference query
Triangulation * __triangulation
the triangulation class creating the junction tree used for inference
void _onAllEvidenceErased(bool has_hard_evidence) final
fired before all the evidence are erased
void __findRelevantPotentialsWithdSeparation3(__PotentialSet &pot_list, Set< const DiscreteVariable * > &kept_vars)
update a set of potentials: the remaining are those to be combined to produce a message on a separato...
bool __is_new_jt_needed
indicates whether a new join tree is needed for the next inference
Potential< GUM_SCALAR > * _unnormalizedJointPosterior(NodeId id) final
returns a fresh potential equal to P(argument,evidence)
ArcProperty< __PotentialSet > __created_potentials
the set of potentials created for the last inference messages
Set< NodeId > NodeSet
Some typdefs and define for shortcuts ...
void setTriangulation(const Triangulation &new_triangulation)
use a new triangulation algorithm
void __findRelevantPotentialsXX(__PotentialSet &pot_list, Set< const DiscreteVariable * > &kept_vars)
update a set of potentials: the remaining are those to be combined to produce a message on a separato...
Copyright 2005-2019 Pierre-Henri WUILLEMIN et Christophe GONZALES (LIP6) {prenom.nom}_at_lip6.fr.
static void relevantPotentials(const IBayesNet< GUM_SCALAR > &bn, const NodeSet &query, const NodeSet &hardEvidence, const NodeSet &softEvidence, Set< const TABLE< GUM_SCALAR > * > &potentials)
update a set of potentials, keeping only those d-connected with query variables given evidence ...
void __findRelevantPotentialsWithdSeparation2(__PotentialSet &pot_list, Set< const DiscreteVariable * > &kept_vars)
update a set of potentials: the remaining are those to be combined to produce a message on a separato...
Copyright 2005-2019 Pierre-Henri WUILLEMIN et Christophe GONZALES (LIP6) {prenom.nom}_at_lip6.fr.
bool exists(const Key &key) const
Checks whether there exists an element with a given key in the hashtable.
NodeSet __hard_ev_nodes
the hard evidence nodes which were projected in CPTs
void erase(const Key &k)
Erases an element from the set.
bool exists(const NodeId id) const
alias for existsNode
Copyright 2005-2019 Pierre-Henri WUILLEMIN et Christophe GONZALES (LIP6) {prenom.nom}_at_lip6.fr.
GUM_SCALAR evidenceProbability() final
returns the probability of evidence
void _onJointTargetAdded(const NodeSet &set) final
fired after a new joint target is inserted
void _onAllJointTargetsErased() final
fired before a all the joint targets are removed
const Potential< GUM_SCALAR > & _jointPosterior(const NodeSet &set) final
returns the posterior of a declared target set
RelevantPotentialsFinderType
type of algorithm for determining the relevant potentials for combinations using some d-separation an...
virtual void eraseNode(const NodeId id)
remove a node and its adjacent edges from the graph
FindBarrenNodesType __barren_nodes_type
the type of barren nodes computation we wish
NodeProperty< const Potential< GUM_SCALAR > *> __target_posteriors
the set of single posteriors computed during the last inference
virtual void makeInference() final
perform the heavy computations needed to compute the targets' posteriors
const GUM_SCALAR __1_minus_epsilon
for comparisons with 1 - epsilon
void _onAllTargetsErased() final
fired before a all single and joint_targets are removed
NodeSet __roots
a clique node used as a root in each connected component of __JT
void _onMarginalTargetErased(const NodeId id) final
fired before a single target is removed
virtual void _onBayesNetChanged(const IBayesNet< GUM_SCALAR > *bn) final
fired after a new Bayes net has been assigned to the engine
void _onEvidenceChanged(const NodeId id, bool hasChangedSoftHard) final
fired after an evidence is changed, in particular when its status (soft/hard) changes ...
FindBarrenNodesType
type of algorithm to determine barren nodes
const Potential< GUM_SCALAR > & _posterior(NodeId id) final
returns the posterior of a given variable
const NodeSet & softEvidenceNodes() const
returns the set of nodes with soft evidence
void __collectMessage(NodeId id, NodeId from)
actually perform the collect phase
ArcProperty< __PotentialSet > __separator_potentials
the list of all potentials stored in the separators after inferences
virtual const NodeProperty< Size > & domainSizes() const final
get the domain sizes of the random variables of the BN
virtual NodeId createdJunctionTreeClique(const NodeId id)=0
returns the Id of the clique created by the elimination of a given node during the triangulation proc...
Copyright 2005-2019 Pierre-Henri WUILLEMIN et Christophe GONZALES (LIP6) {prenom.nom}_at_lip6.fr.
const NodeProperty< Idx > & hardEvidence() const
indicate for each node with hard evidence which value it took
virtual bool isDone() const noexcept final
returns whether the inference object is in a done state
const JoinTree * joinTree()
returns the current join tree used
NodeProperty< GUM_SCALAR > __constants
the constants resulting from the projections of CPTs defined over only hard evidence nodes remove th...
JoinTree * __JT
the join (or junction) tree used to answer the last inference query
CliqueGraph JoinTree
a join tree is a clique graph satisfying the running intersection property (but some cliques may be i...
void _updateOutdatedBNPotentials() final
prepares inference when the latter is in OutdatedBNPotentials state
void _onEvidenceAdded(const NodeId id, bool isHardEvidence) final
fired after a new evidence is inserted
virtual const Set< NodeSet > & jointTargets() const noexcept final
returns the list of joint targets
void _makeInference() final
called when the inference has to be performed effectively
void __diffuseMessageInvalidations(NodeId from_id, NodeId to_id, NodeSet &invalidated_cliques)
invalidate all the messages sent from a given clique
Copyright 2005-2019 Pierre-Henri WUILLEMIN et Christophe GONZALES (LIP6) {prenom.nom}_at_lip6.fr.
void _onAllMarginalTargetsAdded() final
fired after all the nodes of the BN are added as single targets
void __setProjectionFunction(Potential< GUM_SCALAR > *(*proj)(const Potential< GUM_SCALAR > &, const Set< const DiscreteVariable * > &))
sets the operator for performing the projections
void _setOutdatedBNStructureState()
put the inference into an outdated BN structure state
const NodeSet & clique(const NodeId idClique) const
returns the set of nodes included into a given clique
const NodeSet & hardEvidenceNodes() const
returns the set of nodes with hard evidence
NodeProperty< __PotentialSet > __clique_potentials
the list of all potentials stored in the cliques
void __computeJoinTreeRoots()
compute a root for each connected component of __JT
HashTable< NodeSet, NodeId > __joint_target_to_clique
for each set target, assign a clique in the JT that contains it
__PotentialSet __removeBarrenVariables(__PotentialSet &pot_list, Set< const DiscreteVariable * > &del_vars)
CliqueGraph JunctionTree
a junction tree is a clique graph satisfying the running intersection property and such that no cliqu...
void __createNewJT()
create a new junction tree as well as its related data structures
Copyright 2005-2019 Pierre-Henri WUILLEMIN et Christophe GONZALES (LIP6) {prenom.nom}_at_lip6.fr.
static INLINE Potential< GUM_SCALAR > * LPNewprojPotential(const Potential< GUM_SCALAR > &t1, const Set< const DiscreteVariable * > &del_vars)
bool __use_binary_join_tree
indicates whether we should transform junction trees into binary join trees
void __invalidateAllMessages()
invalidate all messages, posteriors and created potentials
Potential< GUM_SCALAR > *(* __projection_op)(const Potential< GUM_SCALAR > &, const Set< const DiscreteVariable * > &)
the operator for performing the projections
__PotentialSet __marginalizeOut(__PotentialSet pot_list, Set< const DiscreteVariable * > &del_vars, Set< const DiscreteVariable * > &kept_vars)
removes variables del_vars from a list of potentials and returns the resulting list ...
void clear()
Removes all the elements in the hash table.
LazyPropagation(const IBayesNet< GUM_SCALAR > *BN, RelevantPotentialsFinderType=RelevantPotentialsFinderType::DSEP_BAYESBALL_POTENTIALS, FindBarrenNodesType=FindBarrenNodesType::FIND_BARREN_NODES, bool use_binary_join_tree=true)
default constructor
const JunctionTree * junctionTree()
returns the current junction tree
void _onJointTargetErased(const NodeSet &set) final
fired before a joint target is removed
Copyright 2005-2019 Pierre-Henri WUILLEMIN et Christophe GONZALES (LIP6) {prenom.nom}_at_lip6.fr.
virtual bool isInferenceReady() const noexcept final
returns whether the inference object is in a ready state
virtual Triangulation * newFactory() const =0
returns a fresh triangulation of the same type as the current object but with an empty graph ...
void _onAllMarginalTargetsErased() final
fired before a all the single targets are removed
virtual const NodeSet & targets() const noexcept final
returns the list of marginal targets
HashTable< NodeId, NodeId > __node_to_clique
for each node of __graph (~ in the Bayes net), associate an ID in the JT
void clear()
Removes all the elements, if any, from the set.
void __produceMessage(NodeId from_id, NodeId to_id)
creates the message sent by clique from_id to clique to_id
std::size_t Size
In aGrUM, hashed values are unsigned long int.
RelevantPotentialsFinderType __find_relevant_potential_type
the type of relevant potential finding algorithm to be used
void(LazyPropagation< GUM_SCALAR >::* __findRelevantPotentials)(Set< const Potential< GUM_SCALAR > * > &pot_list, Set< const DiscreteVariable * > &kept_vars)
update a set of potentials: the remaining are those to be combined to produce a message on a separato...
virtual const CliqueGraph & junctionTree()=0
returns a compatible junction tree
value_type & insert(const Key &key, const Val &val)
Adds a new element (actually a copy of this element) into the hash table.
void __findRelevantPotentialsGetAll(__PotentialSet &pot_list, Set< const DiscreteVariable * > &kept_vars)
update a set of potentials: the remaining are those to be combined to produce a message on a separato...
virtual const std::vector< NodeId > & eliminationOrder()=0
returns an elimination ordering compatible with the triangulated graph
void _onMarginalTargetAdded(const NodeId id) final
fired after a new single target is inserted
virtual const IBayesNet< GUM_SCALAR > & BN() const final
Returns a constant reference over the IBayesNet referenced by this class.
Size NodeId
Type for node ids.
void insert(const Key &k)
Inserts a new element into the set.
UndiGraph __graph
the undigraph extracted from the BN and used to construct the join tree
void __setCombinationFunction(Potential< GUM_SCALAR > *(*comb)(const Potential< GUM_SCALAR > &, const Potential< GUM_SCALAR > &))
sets the operator for performing the combinations
void _onEvidenceErased(const NodeId id, bool isHardEvidence) final
fired before an evidence is removed
#define GUM_ERROR(type, msg)
virtual void setGraph(const UndiGraph *graph, const NodeProperty< Size > *domsizes)=0
initialize the triangulation data structures for a new graph
static void requisiteNodes(const DAG &dag, const NodeSet &query, const NodeSet &hardEvidence, const NodeSet &softEvidence, NodeSet &requisite)
Fill the 'requisite' nodeset with the requisite nodes in dag given a query and evidence.
Potential< GUM_SCALAR > *(* __combination_op)(const Potential< GUM_SCALAR > &, const Potential< GUM_SCALAR > &)
the operator for performing the combinations
void _setOutdatedBNPotentialsState()
puts the inference into an OutdatedBNPotentials state if it is not already in an OutdatedBNStructure ...
NodeProperty< EvidenceChangeType > __evidence_changes
indicates which nodes of the BN have evidence that changed since the last inference ...
Copyright 2005-2019 Pierre-Henri WUILLEMIN et Christophe GONZALES (LIP6) {prenom.nom}_at_lip6.fr.
void __findRelevantPotentialsWithdSeparation(__PotentialSet &pot_list, Set< const DiscreteVariable * > &kept_vars)
update a set of potentials: the remaining are those to be combined to produce a message on a separato...