63 ApproximationScheme::operator=(from);
69 ApproximationScheme::operator=(std::move(from));
81 return e1.second > e2.second;
85 const std::pair< std::tuple< NodeId, NodeId, NodeId >*,
double >& e1,
86 const std::pair< std::tuple< NodeId, NodeId, NodeId >*,
double >& e2)
88 return std::abs(e1.second) > std::abs(e2.second);
93 tuple< std::tuple< NodeId, NodeId, NodeId >*,
double,
double,
double >&
96 tuple< std::tuple< NodeId, NodeId, NodeId >*,
double,
double,
double >&
98 double p1xz = std::get< 2 >(e1);
99 double p1yz = std::get< 3 >(e1);
100 double p2xz = std::get< 2 >(e2);
101 double p2yz = std::get< 3 >(e2);
102 double I1 = std::get< 1 >(e1);
103 double I2 = std::get< 1 >(e2);
104 if (std::max(p1xz, p1yz) == std::max(p2xz, p2yz)) {
107 return std::max(p1xz, p1yz) > std::max(p2xz, p2yz);
122 std::pair< std::tuple< NodeId, NodeId, NodeId, std::vector< NodeId > >*,
153 HashTable< std::pair< NodeId, NodeId >, std::vector< NodeId > >& sep_set,
162 for (
const Edge& edge: edges) {
165 double Ixy = I.
score(x, y);
191 HashTable< std::pair< NodeId, NodeId >, std::vector< NodeId > >& sep_set,
197 std::pair< std::tuple< NodeId, NodeId, NodeId, std::vector< NodeId > >*,
202 Size steps_iter = _rank.size();
205 while (_rank.top().second > 0.5) {
208 const NodeId x = std::get< 0 >(*(best.first));
209 const NodeId y = std::get< 1 >(*(best.first));
210 const NodeId z = std::get< 2 >(*(best.first));
211 std::vector< NodeId > ui = std::move(std::get< 3 >(*(best.first)));
214 const double Ixy_ui = I.
score(x, y, ui);
217 sep_set.insert(std::make_pair(x, y), std::move(ui));
248 const HashTable< std::pair< NodeId, NodeId >, std::vector< NodeId > >&
250 std::vector< std::pair< std::tuple< NodeId, NodeId, NodeId >*,
double > >
252 Size steps_orient = triples.size();
260 if (graph.
existsEdge(iter.key().first, iter.key().second)
261 && iter.val() ==
'>') {
263 graph.
addArc(iter.key().first, iter.key().second);
272 while (i < triples.size()) {
274 std::pair< std::tuple< NodeId, NodeId, NodeId >*,
double > triple =
277 x = std::get< 0 >(*triple.first);
278 y = std::get< 1 >(*triple.first);
279 z = std::get< 2 >(*triple.first);
281 std::vector< NodeId > ui;
282 std::pair< NodeId, NodeId > key = {x, y};
283 std::pair< NodeId, NodeId > rev_key = {y, x};
284 if (sep_set.exists(key)) {
286 }
else if (sep_set.exists(rev_key)) {
287 ui = sep_set[rev_key];
289 double Ixyz_ui = triple.second;
294 if (std::find(ui.begin(), ui.end(), z) == ui.end()) {
436 const HashTable< std::pair< NodeId, NodeId >, std::vector< NodeId > >&
438 std::vector< std::pair< std::tuple< NodeId, NodeId, NodeId >*,
double > >
440 Size steps_orient = triples.size();
448 while (i < triples.size()) {
450 std::pair< std::tuple< NodeId, NodeId, NodeId >*,
double > triple =
453 x = std::get< 0 >(*triple.first);
454 y = std::get< 1 >(*triple.first);
455 z = std::get< 2 >(*triple.first);
457 std::vector< NodeId > ui;
458 std::pair< NodeId, NodeId > key = {x, y};
459 std::pair< NodeId, NodeId > rev_key = {y, x};
460 if (sep_set.exists(key)) {
462 }
else if (sep_set.exists(rev_key)) {
463 ui = sep_set[rev_key];
465 double Ixyz_ui = triple.second;
469 if (std::find(ui.begin(), ui.end(), z) == ui.end()) {
587 const HashTable< std::pair< NodeId, NodeId >,
588 std::vector< NodeId > >& sep_set) {
596 for (
auto iter = marks.
begin(); iter != marks.
end(); ++iter) {
597 if (graph.
existsEdge(iter.key().first, iter.key().second)
598 && iter.val() ==
'>') {
600 graph.
addArc(iter.key().first, iter.key().second);
604 std::vector< std::tuple< std::tuple< NodeId, NodeId, NodeId >*,
610 Size steps_orient = proba_triples.size();
613 std::tuple< std::tuple< NodeId, NodeId, NodeId >*, double, double,
double >
615 if (steps_orient > 0) { best = proba_triples[0]; }
617 while (!proba_triples.empty()
618 && std::max(std::get< 2 >(best), std::get< 3 >(best)) > 0.5) {
620 x = std::get< 0 >(*std::get< 0 >(best));
621 y = std::get< 1 >(*std::get< 0 >(best));
622 z = std::get< 2 >(*std::get< 0 >(best));
624 const double i3 = std::get< 1 >(best);
628 if (marks[{x, z}] ==
'o' && marks[{y, z}] ==
'o') {
657 }
else if (marks[{x, z}] ==
'>' && marks[{y, z}] ==
'o') {
672 }
else if (marks[{y, z}] ==
'>' && marks[{x, z}] ==
'o') {
691 if (marks[{x, z}] ==
'>' && marks[{y, z}] ==
'o' 692 && marks[{z, y}] !=
'-') {
709 }
else if (marks[{y, z}] ==
'>' && marks[{x, z}] ==
'o' 710 && marks[{z, x}] !=
'-') {
730 delete std::get< 0 >(best);
731 proba_triples.erase(proba_triples.begin());
735 best = proba_triples[0];
752 graph.
addArc(iter->head(), iter->tail());
754 *iter =
Arc(iter->head(), iter->tail());
767 const std::vector< NodeId >& ui,
779 const double Ixy_ui = I.
score(x, y, ui);
781 for (
const NodeId z: graph) {
783 if (z != x && z != y && std::find(ui.begin(), ui.end(), z) == ui.end()) {
788 const double Ixyz_ui = I.
score(x, y, z, ui);
789 double calc_expo1 = -Ixyz_ui *
M_LN2;
793 }
else if (calc_expo1 < -
__maxLog) {
796 Pnv = 1 / (1 + std::exp(calc_expo1));
800 const double Ixz_ui = I.
score(x, z, ui);
801 const double Iyz_ui = I.
score(y, z, ui);
803 calc_expo1 = -(Ixz_ui - Ixy_ui) * M_LN2;
804 double calc_expo2 = -(Iyz_ui - Ixy_ui) * M_LN2;
816 expo1 = std::exp(calc_expo1);
821 expo2 = std::exp(calc_expo2);
823 Pb = 1 / (1 + expo1 + expo2);
827 const double min_pnv_pb = std::min(Pnv, Pb);
828 if (min_pnv_pb > maxP) {
835 std::pair< std::tuple< NodeId, NodeId, NodeId, std::vector< NodeId > >*,
838 auto tup =
new std::tuple< NodeId, NodeId, NodeId, std::vector< NodeId > >{
847 std::vector< std::pair< std::tuple< NodeId, NodeId, NodeId >*,
double > >
851 const HashTable< std::pair< NodeId, NodeId >, std::vector< NodeId > >&
853 std::vector< std::pair< std::tuple< NodeId, NodeId, NodeId >*,
double > >
856 for (
NodeId x: graph.neighbours(z)) {
857 for (
NodeId y: graph.neighbours(z)) {
858 if (y < x && !graph.existsEdge(x, y)) {
859 std::vector< NodeId > ui;
860 std::pair< NodeId, NodeId > key = {x, y};
861 std::pair< NodeId, NodeId > rev_key = {y, x};
862 if (sep_set.exists(key)) {
864 }
else if (sep_set.exists(rev_key)) {
865 ui = sep_set[rev_key];
868 const auto iter_z_place = std::find(ui.begin(), ui.end(), z);
869 if (iter_z_place != ui.end()) { ui.erase(iter_z_place); }
871 double Ixyz_ui = I.
score(x, y, z, ui);
872 std::pair< std::tuple< NodeId, NodeId, NodeId >*,
double > triple;
873 auto tup =
new std::tuple< NodeId, NodeId, NodeId >{x, y, z};
875 triple.second = Ixyz_ui;
876 triples.push_back(triple);
889 tuple< std::tuple< NodeId, NodeId, NodeId >*,
double, double,
double > >
893 const HashTable< std::pair< NodeId, NodeId >, std::vector< NodeId > >&
895 HashTable< std::pair< NodeId, NodeId >,
char >& marks) {
896 std::vector< std::tuple< std::tuple< NodeId, NodeId, NodeId >*,
902 for (
NodeId x: graph.neighbours(z)) {
903 for (
NodeId y: graph.neighbours(z)) {
904 if (y < x && !graph.existsEdge(x, y)) {
905 std::vector< NodeId > ui;
906 std::pair< NodeId, NodeId > key = {x, y};
907 std::pair< NodeId, NodeId > rev_key = {y, x};
908 if (sep_set.exists(key)) {
910 }
else if (sep_set.exists(rev_key)) {
911 ui = sep_set[rev_key];
914 const auto iter_z_place = std::find(ui.begin(), ui.end(), z);
915 if (iter_z_place != ui.end()) { ui.erase(iter_z_place); }
917 const double Ixyz_ui = I.
score(x, y, z, ui);
918 auto tup =
new std::tuple< NodeId, NodeId, NodeId >{x, y, z};
919 std::tuple< std::tuple< NodeId, NodeId, NodeId >*,
923 triple{tup, Ixyz_ui, 0.5, 0.5};
924 triples.push_back(triple);
925 if (!marks.exists({x, z})) { marks.insert({x, z},
'o'); }
926 if (!marks.exists({z, x})) { marks.insert({z, x},
'o'); }
927 if (!marks.exists({y, z})) { marks.insert({y, z},
'o'); }
928 if (!marks.exists({z, y})) { marks.insert({z, y},
'o'); }
941 tuple< std::tuple< NodeId, NodeId, NodeId >*, double, double,
double > >
944 std::vector< std::tuple< std::tuple< NodeId, NodeId, NodeId >*,
947 double > > proba_triples) {
948 for (
auto& triple: proba_triples) {
950 x = std::get< 0 >(*std::get< 0 >(triple));
951 y = std::get< 1 >(*std::get< 0 >(triple));
952 z = std::get< 2 >(*std::get< 0 >(triple));
953 const double Ixyz = std::get< 1 >(triple);
954 double Pxz = std::get< 2 >(triple);
955 double Pyz = std::get< 3 >(triple);
958 const double expo = std::exp(Ixyz);
959 const double P0 = (1 + expo) / (1 + 3 * expo);
961 if (Pxz == Pyz && Pyz == 0.5) {
962 std::get< 2 >(triple) = P0;
963 std::get< 3 >(triple) = P0;
965 if (graph.
existsArc(x, z) && Pxz >= P0) {
966 std::get< 3 >(triple) = Pxz * (1 / (1 + expo) - 0.5) + 0.5;
967 }
else if (graph.
existsArc(y, z) && Pyz >= P0) {
968 std::get< 2 >(triple) = Pyz * (1 / (1 + expo) - 0.5) + 0.5;
972 const double expo = std::exp(-Ixyz);
973 if (graph.
existsArc(x, z) && Pxz >= 0.5) {
974 std::get< 3 >(triple) = Pxz * (1 / (1 + expo) - 0.5) + 0.5;
975 }
else if (graph.
existsArc(y, z) && Pyz >= 0.5) {
976 std::get< 2 >(triple) = Pyz * (1 / (1 + expo) - 0.5) + 0.5;
981 return proba_triples;
1007 for (
auto node: essentialGraph) {
1010 for (
const Arc& arc: essentialGraph.arcs()) {
1011 dag.
addArc(arc.tail(), arc.head());
1019 const auto neighbours = graph.
neighbours(node);
1020 for (
auto& neighbour: neighbours) {
1028 graph.
addArc(node, neighbour);
1035 graph.
addArc(neighbour, node);
1039 graph.
addArc(node, neighbour);
1053 template <
typename GUM_SCALAR,
1054 typename GRAPH_CHANGES_SELECTOR,
1055 typename PARAM_ESTIMATOR >
1057 PARAM_ESTIMATOR& estimator,
1067 HashTable< std::pair< NodeId, NodeId >,
char > constraints) {
1085 while (!nodeFIFO.
empty()) {
1086 current = nodeFIFO.
front();
1091 for (
const auto new_one: graph.
parents(current)) {
1092 if (mark.
exists(new_one))
1099 mark.
insert(new_one, current);
1101 if (new_one == n1) {
return true; }
Copyright 2005-2020 Pierre-Henri WUILLEMIN () et Christophe GONZALES () info_at_agrum_dot_org.
Copyright 2005-2020 Pierre-Henri WUILLEMIN () et Christophe GONZALES () info_at_agrum_dot_org.
iterator begin()
Returns an unsafe iterator pointing to the beginning of the hashtable.
void _iteration(CorrectedMutualInformation<> &I, MixedGraph &graph, HashTable< std::pair< NodeId, NodeId >, std::vector< NodeId > > &sep_set, Heap< std::pair< std::tuple< NodeId, NodeId, NodeId, std::vector< NodeId > > *, double >, GreaterPairOn2nd > &_rank)
Iteration phase.
bool empty() const noexcept
Returns a boolean indicating whether the chained list is empty.
Class representing a Bayesian Network.
ArcProperty< double > __arc_probas
Stores the probabilities for each arc set in the graph.
void _findBestContributor(NodeId x, NodeId y, const std::vector< NodeId > &ui, const MixedGraph &graph, CorrectedMutualInformation<> &I, Heap< std::pair< std::tuple< NodeId, NodeId, NodeId, std::vector< NodeId > > *, double >, GreaterPairOn2nd > &_rank)
finds the best contributor node for a pair given a conditioning set
virtual void addNodeWithId(const NodeId id)
try to insert a node with the given id
virtual void addArc(const NodeId tail, const NodeId head)
insert a new arc into the directed graph
double step() const
Returns the delta time between now and the last reset() call (or the constructor).
Signaler3< Size, double, double > onProgress
Progression, error and time.
const iterator & end() noexcept
Returns the unsafe iterator pointing to the end of the hashtable.
bool empty() const noexcept
Indicates whether the set is the empty set.
void set3off2Behaviour()
Sets the orientation phase to follow the one of the 3off2 algorithm.
MixedGraph learnMixedStructure(CorrectedMutualInformation<> &I, MixedGraph graph)
learns the structure of an Essential Graph
virtual void eraseArc(const Arc &arc)
removes an arc from the ArcGraphPart
Miic & operator=(const Miic &from)
copy operator
void addConstraints(HashTable< std::pair< NodeId, NodeId >, char > constraints)
Sets a collection of constraints for the orientation phase.
void _orientation_latents(CorrectedMutualInformation<> &I, MixedGraph &graph, const HashTable< std::pair< NodeId, NodeId >, std::vector< NodeId > > &sep_set)
Modified version of the orientation phase that tries to propagate orientations from both orientations...
bool exists(const Key &key) const
Checks whether there exists an element with a given key in the hashtable.
void _orientation_miic(CorrectedMutualInformation<> &I, MixedGraph &graph, const HashTable< std::pair< NodeId, NodeId >, std::vector< NodeId > > &sep_set)
Orientation phase from the MIIC algorithm, returns a mixed graph that may contain circles...
Generic doubly linked lists.
Copyright 2005-2020 Pierre-Henri WUILLEMIN () et Christophe GONZALES () info_at_agrum_dot_org.
const std::vector< Arc > latentVariables() const
get the list of arcs hiding latent variables
Copyright 2005-2020 Pierre-Henri WUILLEMIN () et Christophe GONZALES () info_at_agrum_dot_org.
bool __usemiic
whether to use the MIIC algorithm or not
void popFront()
Removes the first element of a List, if any.
void setMiicBehaviour()
Sets the orientation phase to follow the one of the MIIC algorithm.
void _propagatesHead(MixedGraph &graph, NodeId node)
Propagates the orientation from a node to its neighbours.
The class for generic Hash Tables.
Copyright 2005-2020 Pierre-Henri WUILLEMIN () et Christophe GONZALES () info_at_agrum_dot_org.
Copyright 2005-2020 Pierre-Henri WUILLEMIN () et Christophe GONZALES () info_at_agrum_dot_org.
const NodeSet & neighbours(const NodeId id) const
returns the set of nodes adjacent to a given node through an edge
bool operator()(const std::pair< std::tuple< NodeId, NodeId, NodeId > *, double > &e1, const std::pair< std::tuple< NodeId, NodeId, NodeId > *, double > &e2) const
void reset()
Reset the timer.
const EdgeSet & edges() const
returns the set of edges stored within the EdgeGraphPart
HashTable< std::pair< NodeId, NodeId >, char > __initial_marks
Initial marks for the orientation phase, used to convey constraints.
std::vector< Arc > __latent_couples
the arcs hiding latent variables (initially empty)
bool existsEdge(const Edge &edge) const
indicates whether a given edge exists
The base class for all directed edges. This class is used as a basis for manipulating all directed edge...
std::vector< std::tuple< std::tuple< NodeId, NodeId, NodeId > *, double, double, double > > _getUnshieldedTriplesMIIC(const MixedGraph &graph, CorrectedMutualInformation<> &I, const HashTable< std::pair< NodeId, NodeId >, std::vector< NodeId > > &sep_set, HashTable< std::pair< NodeId, NodeId >, char > &marks)
gets the list of unshielded triples in the graph in decreasing value of |I'(x, y, z|{ui})|...
const NodeSet & parents(const NodeId id) const
returns the set of nodes with arc ingoing to a given node
Val & pushBack(const Val &val)
Inserts a new element (a copy) at the end of the chained list.
Size _current_step
The current step.
const std::vector< NodeId > directedPath(const NodeId node1, const NodeId node2) const
returns a directed path from node1 to node2 belonging to the set of arcs
virtual void addArc(const NodeId tail, const NodeId head)
insert a new arc into the directed graph
const Sequence< NodeId > & topologicalOrder(bool clear=true) const
The topological order stays the same as long as no variable or arcs are added or erased src the topol...
void _initiation(CorrectedMutualInformation<> &I, MixedGraph &graph, HashTable< std::pair< NodeId, NodeId >, std::vector< NodeId > > &sep_set, Heap< std::pair< std::tuple< NodeId, NodeId, NodeId, std::vector< NodeId > > *, double >, GreaterPairOn2nd > &_rank)
Initiation phase.
std::vector< std::pair< std::tuple< NodeId, NodeId, NodeId > *, double > > _getUnshieldedTriples(const MixedGraph &graph, CorrectedMutualInformation<> &I, const HashTable< std::pair< NodeId, NodeId >, std::vector< NodeId > > &sep_set)
gets the list of unshielded triples in the graph in decreasing value of |I'(x, y, z|{ui})| ...
Val & front() const
Returns a reference to first element of a list, if any.
The base class for all undirected edges.
Miic()
default constructor
std::vector< std::tuple< std::tuple< NodeId, NodeId, NodeId > *, double, double, double > > _updateProbaTriples(const MixedGraph &graph, std::vector< std::tuple< std::tuple< NodeId, NodeId, NodeId > *, double, double, double > > proba_triples)
Gets the orientation probabilities like MIIC for the orientation phase.
Copyright 2005-2020 Pierre-Henri WUILLEMIN () et Christophe GONZALES () info_at_agrum_dot_org.
bool operator()(const std::tuple< std::tuple< NodeId, NodeId, NodeId > *, double, double, double > &e1, const std::tuple< std::tuple< NodeId, NodeId, NodeId > *, double, double, double > &e2) const
int __maxLog
Fixes the maximum log that we accept in exponential computations.
The miic learning algorithm.
const bool __existsDirectedPath(const MixedGraph &graph, const NodeId n1, const NodeId n2) const
checks for directed paths in a graph, considering double arcs as edges
A class that, given a structure and a parameter estimator returns a full Bayes net.
bool operator()(const std::pair< std::tuple< NodeId, NodeId, NodeId, std::vector< NodeId > > *, double > &e1, const std::pair< std::tuple< NodeId, NodeId, NodeId, std::vector< NodeId > > *, double > &e2) const
Heap data structure. This structure is a basic heap data structure, i.e., it is a container in which el...
std::size_t Size
In aGrUM, hashed values are unsigned long int.
const std::vector< NodeId > __empty_set
an empty conditioning set
virtual void eraseEdge(const Edge &edge)
removes an edge from the EdgeGraphPart
Size size() const noexcept
Returns the number of elements in the set.
value_type & insert(const Key &key, const Val &val)
Adds a new element (actually a copy of this element) into the hash table.
#define GUM_EMIT3(signal, arg1, arg2, arg3)
BayesNet< GUM_SCALAR > learnBN(GRAPH_CHANGES_SELECTOR &selector, PARAM_ESTIMATOR &estimator, DAG initial_dag=DAG())
learns the structure and the parameters of a BN
bool existsArc(const Arc &arc) const
indicates whether a given arc exists
Copyright 2005-2020 Pierre-Henri WUILLEMIN () et Christophe GONZALES () info_at_agrum_dot_org.
Size NodeId
Type for node ids.
DAG learnStructure(CorrectedMutualInformation<> &I, MixedGraph graph)
learns the structure of a Bayesian network, i.e. a DAG, by first learning an Essential graph and then ...
Base class for mixed graphs.
void _orientation_3off2(CorrectedMutualInformation<> &I, MixedGraph &graph, const HashTable< std::pair< NodeId, NodeId >, std::vector< NodeId > > &sep_set)
Orientation phase from the 3off2 algorithm, returns a CPDAG.