64 ApproximationScheme::operator=(from);
70 ApproximationScheme::operator=(std::move(from));
82 return e1.second > e2.second;
86 const std::pair< std::tuple< NodeId, NodeId, NodeId >*,
double >& e1,
87 const std::pair< std::tuple< NodeId, NodeId, NodeId >*,
double >& e2)
89 return std::abs(e1.second) > std::abs(e2.second);
94 tuple< std::tuple< NodeId, NodeId, NodeId >*,
double,
double,
double >&
97 tuple< std::tuple< NodeId, NodeId, NodeId >*,
double,
double,
double >&
99 double p1xz = std::get< 2 >(e1);
100 double p1yz = std::get< 3 >(e1);
101 double p2xz = std::get< 2 >(e2);
102 double p2yz = std::get< 3 >(e2);
103 double I1 = std::get< 1 >(e1);
104 double I2 = std::get< 1 >(e2);
105 if (std::max(p1xz, p1yz) == std::max(p2xz, p2yz)) {
108 return std::max(p1xz, p1yz) > std::max(p2xz, p2yz);
123 std::pair< std::tuple< NodeId, NodeId, NodeId, std::vector< NodeId > >*,
154 HashTable< std::pair< NodeId, NodeId >, std::vector< NodeId > >& sep_set,
163 for (
const Edge& edge : edges) {
166 double Ixy = I.
score(x, y);
192 HashTable< std::pair< NodeId, NodeId >, std::vector< NodeId > >& sep_set,
198 std::pair< std::tuple< NodeId, NodeId, NodeId, std::vector< NodeId > >*,
203 Size steps_iter = _rank.size();
206 while (_rank.top().second > 0.5) {
209 const NodeId x = std::get< 0 >(*(best.first));
210 const NodeId y = std::get< 1 >(*(best.first));
211 const NodeId z = std::get< 2 >(*(best.first));
212 std::vector< NodeId > ui = std::move(std::get< 3 >(*(best.first)));
215 const double Ixy_ui = I.
score(x, y, ui);
218 sep_set.insert(std::make_pair(x, y), std::move(ui));
249 const HashTable< std::pair< NodeId, NodeId >, std::vector< NodeId > >&
251 std::vector< std::pair< std::tuple< NodeId, NodeId, NodeId >*,
double > >
253 Size steps_orient = triples.size();
261 if (graph.
existsEdge(iter.key().first, iter.key().second)
262 && iter.val() ==
'>') {
264 graph.
addArc(iter.key().first, iter.key().second);
273 while (i < triples.size()) {
275 std::pair< std::tuple< NodeId, NodeId, NodeId >*,
double > triple =
278 x = std::get< 0 >(*triple.first);
279 y = std::get< 1 >(*triple.first);
280 z = std::get< 2 >(*triple.first);
282 std::vector< NodeId > ui;
283 std::pair< NodeId, NodeId > key = {x, y};
284 std::pair< NodeId, NodeId > rev_key = {y, x};
285 if (sep_set.exists(key)) {
287 }
else if (sep_set.exists(rev_key)) {
288 ui = sep_set[rev_key];
290 double Ixyz_ui = triple.second;
295 if (std::find(ui.begin(), ui.end(), z) == ui.end()) {
437 const HashTable< std::pair< NodeId, NodeId >, std::vector< NodeId > >&
439 std::vector< std::pair< std::tuple< NodeId, NodeId, NodeId >*,
double > >
441 Size steps_orient = triples.size();
449 while (i < triples.size()) {
451 std::pair< std::tuple< NodeId, NodeId, NodeId >*,
double > triple =
454 x = std::get< 0 >(*triple.first);
455 y = std::get< 1 >(*triple.first);
456 z = std::get< 2 >(*triple.first);
458 std::vector< NodeId > ui;
459 std::pair< NodeId, NodeId > key = {x, y};
460 std::pair< NodeId, NodeId > rev_key = {y, x};
461 if (sep_set.exists(key)) {
463 }
else if (sep_set.exists(rev_key)) {
464 ui = sep_set[rev_key];
466 double Ixyz_ui = triple.second;
470 if (std::find(ui.begin(), ui.end(), z) == ui.end()) {
588 const HashTable< std::pair< NodeId, NodeId >,
589 std::vector< NodeId > >& sep_set) {
597 for (
auto iter = marks.
begin(); iter != marks.
end(); ++iter) {
598 if (graph.
existsEdge(iter.key().first, iter.key().second)
599 && iter.val() ==
'>') {
601 graph.
addArc(iter.key().first, iter.key().second);
605 std::vector< std::tuple< std::tuple< NodeId, NodeId, NodeId >*,
611 Size steps_orient = proba_triples.size();
614 std::tuple< std::tuple< NodeId, NodeId, NodeId >*, double, double,
double >
616 if (steps_orient > 0) { best = proba_triples[0]; }
618 while (!proba_triples.empty()
619 && std::max(std::get< 2 >(best), std::get< 3 >(best)) > 0.5) {
621 x = std::get< 0 >(*std::get< 0 >(best));
622 y = std::get< 1 >(*std::get< 0 >(best));
623 z = std::get< 2 >(*std::get< 0 >(best));
625 const double i3 = std::get< 1 >(best);
629 if (marks[{x, z}] ==
'o' && marks[{y, z}] ==
'o') {
658 }
else if (marks[{x, z}] ==
'>' && marks[{y, z}] ==
'o') {
673 }
else if (marks[{y, z}] ==
'>' && marks[{x, z}] ==
'o') {
692 if (marks[{x, z}] ==
'>' && marks[{y, z}] ==
'o' 693 && marks[{z, y}] !=
'-') {
710 }
else if (marks[{y, z}] ==
'>' && marks[{x, z}] ==
'o' 711 && marks[{z, x}] !=
'-') {
731 delete std::get< 0 >(best);
732 proba_triples.erase(proba_triples.begin());
736 best = proba_triples[0];
753 graph.
addArc(iter->head(), iter->tail());
755 *iter =
Arc(iter->head(), iter->tail());
768 const std::vector< NodeId >& ui,
780 const double Ixy_ui = I.
score(x, y, ui);
782 for (
const NodeId z : graph) {
784 if (z != x && z != y && std::find(ui.begin(), ui.end(), z) == ui.end()) {
789 const double Ixyz_ui = I.
score(x, y, z, ui);
790 double calc_expo1 = -Ixyz_ui *
M_LN2;
794 }
else if (calc_expo1 < -
__maxLog) {
797 Pnv = 1 / (1 + std::exp(calc_expo1));
801 const double Ixz_ui = I.
score(x, z, ui);
802 const double Iyz_ui = I.
score(y, z, ui);
804 calc_expo1 = -(Ixz_ui - Ixy_ui) * M_LN2;
805 double calc_expo2 = -(Iyz_ui - Ixy_ui) * M_LN2;
817 expo1 = std::exp(calc_expo1);
822 expo2 = std::exp(calc_expo2);
824 Pb = 1 / (1 + expo1 + expo2);
828 const double min_pnv_pb = std::min(Pnv, Pb);
829 if (min_pnv_pb > maxP) {
836 std::pair< std::tuple< NodeId, NodeId, NodeId, std::vector< NodeId > >*,
839 auto tup =
new std::tuple< NodeId, NodeId, NodeId, std::vector< NodeId > >{
848 std::vector< std::pair< std::tuple< NodeId, NodeId, NodeId >*,
double > >
852 const HashTable< std::pair< NodeId, NodeId >, std::vector< NodeId > >&
854 std::vector< std::pair< std::tuple< NodeId, NodeId, NodeId >*,
double > >
857 for (
NodeId x : graph.neighbours(z)) {
858 for (
NodeId y : graph.neighbours(z)) {
859 if (y < x && !graph.existsEdge(x, y)) {
860 std::vector< NodeId > ui;
861 std::pair< NodeId, NodeId > key = {x, y};
862 std::pair< NodeId, NodeId > rev_key = {y, x};
863 if (sep_set.exists(key)) {
865 }
else if (sep_set.exists(rev_key)) {
866 ui = sep_set[rev_key];
869 const auto iter_z_place = std::find(ui.begin(), ui.end(), z);
870 if (iter_z_place != ui.end()) { ui.erase(iter_z_place); }
872 double Ixyz_ui = I.
score(x, y, z, ui);
873 std::pair< std::tuple< NodeId, NodeId, NodeId >*,
double > triple;
874 auto tup =
new std::tuple< NodeId, NodeId, NodeId >{x, y, z};
876 triple.second = Ixyz_ui;
877 triples.push_back(triple);
890 tuple< std::tuple< NodeId, NodeId, NodeId >*,
double, double,
double > >
894 const HashTable< std::pair< NodeId, NodeId >, std::vector< NodeId > >&
896 HashTable< std::pair< NodeId, NodeId >,
char >& marks) {
897 std::vector< std::tuple< std::tuple< NodeId, NodeId, NodeId >*,
903 for (
NodeId x : graph.neighbours(z)) {
904 for (
NodeId y : graph.neighbours(z)) {
905 if (y < x && !graph.existsEdge(x, y)) {
906 std::vector< NodeId > ui;
907 std::pair< NodeId, NodeId > key = {x, y};
908 std::pair< NodeId, NodeId > rev_key = {y, x};
909 if (sep_set.exists(key)) {
911 }
else if (sep_set.exists(rev_key)) {
912 ui = sep_set[rev_key];
915 const auto iter_z_place = std::find(ui.begin(), ui.end(), z);
916 if (iter_z_place != ui.end()) { ui.erase(iter_z_place); }
918 const double Ixyz_ui = I.
score(x, y, z, ui);
919 auto tup =
new std::tuple< NodeId, NodeId, NodeId >{x, y, z};
920 std::tuple< std::tuple< NodeId, NodeId, NodeId >*,
924 triple{tup, Ixyz_ui, 0.5, 0.5};
925 triples.push_back(triple);
926 if (!marks.exists({x, z})) { marks.insert({x, z},
'o'); }
927 if (!marks.exists({z, x})) { marks.insert({z, x},
'o'); }
928 if (!marks.exists({y, z})) { marks.insert({y, z},
'o'); }
929 if (!marks.exists({z, y})) { marks.insert({z, y},
'o'); }
942 tuple< std::tuple< NodeId, NodeId, NodeId >*, double, double,
double > >
945 std::vector< std::tuple< std::tuple< NodeId, NodeId, NodeId >*,
948 double > > proba_triples) {
949 for (
auto& triple : proba_triples) {
951 x = std::get< 0 >(*std::get< 0 >(triple));
952 y = std::get< 1 >(*std::get< 0 >(triple));
953 z = std::get< 2 >(*std::get< 0 >(triple));
954 const double Ixyz = std::get< 1 >(triple);
955 double Pxz = std::get< 2 >(triple);
956 double Pyz = std::get< 3 >(triple);
959 const double expo = std::exp(Ixyz);
960 const double P0 = (1 + expo) / (1 + 3 * expo);
962 if (Pxz == Pyz && Pyz == 0.5) {
963 std::get< 2 >(triple) = P0;
964 std::get< 3 >(triple) = P0;
966 if (graph.
existsArc(x, z) && Pxz >= P0) {
967 std::get< 3 >(triple) = Pxz * (1 / (1 + expo) - 0.5) + 0.5;
968 }
else if (graph.
existsArc(y, z) && Pyz >= P0) {
969 std::get< 2 >(triple) = Pyz * (1 / (1 + expo) - 0.5) + 0.5;
973 const double expo = std::exp(-Ixyz *
__N);
974 if (graph.
existsArc(x, z) && Pxz >= 0.5) {
975 std::get< 3 >(triple) = Pxz * (1 / (1 + expo) - 0.5) + 0.5;
976 }
else if (graph.
existsArc(y, z) && Pyz >= 0.5) {
977 std::get< 2 >(triple) = Pyz * (1 / (1 + expo) - 0.5) + 0.5;
982 return proba_triples;
1008 for (
auto node : essentialGraph) {
1011 for (
const Arc& arc : essentialGraph.arcs()) {
1012 dag.
addArc(arc.tail(), arc.head());
1020 const auto neighbours = graph.
neighbours(node);
1021 for (
auto& neighbour : neighbours) {
1029 graph.
addArc(node, neighbour);
1036 graph.
addArc(neighbour, node);
1040 graph.
addArc(node, neighbour);
1054 template <
typename GUM_SCALAR,
1055 typename GRAPH_CHANGES_SELECTOR,
1056 typename PARAM_ESTIMATOR >
1058 PARAM_ESTIMATOR& estimator,
1068 HashTable< std::pair< NodeId, NodeId >,
char > constraints) {
1086 while (!nodeFIFO.
empty()) {
1087 current = nodeFIFO.
front();
1092 for (
const auto new_one : graph.
parents(current)) {
1093 if (mark.
exists(new_one))
1100 mark.
insert(new_one, current);
1102 if (new_one == n1) {
return true; }
Copyright 2005-2019 Pierre-Henri WUILLEMIN et Christophe GONZALES (LIP6) {prenom.nom}_at_lip6.fr.
Copyright 2005-2019 Pierre-Henri WUILLEMIN et Christophe GONZALES (LIP6) {prenom.nom}_at_lip6.fr.
iterator begin()
Returns an unsafe iterator pointing to the beginning of the hashtable.
void _iteration(CorrectedMutualInformation<> &I, MixedGraph &graph, HashTable< std::pair< NodeId, NodeId >, std::vector< NodeId > > &sep_set, Heap< std::pair< std::tuple< NodeId, NodeId, NodeId, std::vector< NodeId > > *, double >, GreaterPairOn2nd > &_rank)
Iteration phase.
bool empty() const noexcept
Returns a boolean indicating whether the chained list is empty.
Class representing a Bayesian Network.
ArcProperty< double > __arc_probas
Stores the probabilities for each arc set in the graph.
void _findBestContributor(NodeId x, NodeId y, const std::vector< NodeId > &ui, const MixedGraph &graph, CorrectedMutualInformation<> &I, Heap< std::pair< std::tuple< NodeId, NodeId, NodeId, std::vector< NodeId > > *, double >, GreaterPairOn2nd > &_rank)
finds the best contributor node for a pair given a conditioning set
virtual void addNodeWithId(const NodeId id)
try to insert a node with the given id
virtual void addArc(const NodeId tail, const NodeId head)
insert a new arc into the directed graph
double step() const
Returns the delta time between now and the last reset() call (or the constructor).
Signaler3< Size, double, double > onProgress
Progression, error and time.
const iterator & end() noexcept
Returns the unsafe iterator pointing to the end of the hashtable.
bool empty() const noexcept
Indicates whether the set is the empty set.
void set3off2Behaviour()
Sets the orientation phase to follow the one of the 3off2 algorithm.
MixedGraph learnMixedStructure(CorrectedMutualInformation<> &I, MixedGraph graph)
learns the structure of an Essential Graph
virtual void eraseArc(const Arc &arc)
removes an arc from the ArcGraphPart
Miic & operator=(const Miic &from)
copy operator
void addConstraints(HashTable< std::pair< NodeId, NodeId >, char > constraints)
Sets a collection of constraints for the orientation phase.
void _orientation_latents(CorrectedMutualInformation<> &I, MixedGraph &graph, const HashTable< std::pair< NodeId, NodeId >, std::vector< NodeId > > &sep_set)
Modified version of the orientation phase that tries to propagate orientations from both orientations...
bool exists(const Key &key) const
Checks whether there exists an element with a given key in the hashtable.
void _orientation_miic(CorrectedMutualInformation<> &I, MixedGraph &graph, const HashTable< std::pair< NodeId, NodeId >, std::vector< NodeId > > &sep_set)
Orientation phase from the MIIC algorithm, returns a mixed graph that may contain circles...
Generic doubly linked lists.
Copyright 2005-2019 Pierre-Henri WUILLEMIN et Christophe GONZALES (LIP6) {prenom.nom}_at_lip6.fr.
const std::vector< Arc > latentVariables() const
get the list of arcs hiding latent variables
Copyright 2005-2019 Pierre-Henri WUILLEMIN et Christophe GONZALES (LIP6) {prenom.nom}_at_lip6.fr.
bool __usemiic
whether to use the MIIC algorithm or not
void popFront()
Removes the first element of a List, if any.
void setMiicBehaviour()
Sets the orientation phase to follow the one of the MIIC algorithm.
void _propagatesHead(MixedGraph &graph, NodeId node)
Propagates the orientation from a node to its neighbours.
The class for generic Hash Tables.
Copyright 2005-2019 Pierre-Henri WUILLEMIN et Christophe GONZALES (LIP6) {prenom.nom}_at_lip6.fr.
Copyright 2005-2019 Pierre-Henri WUILLEMIN et Christophe GONZALES (LIP6) {prenom.nom}_at_lip6.fr.
const NodeSet & neighbours(const NodeId id) const
returns the set of nodes connected to a given node by an edge
bool operator()(const std::pair< std::tuple< NodeId, NodeId, NodeId > *, double > &e1, const std::pair< std::tuple< NodeId, NodeId, NodeId > *, double > &e2) const
void reset()
Reset the timer.
const EdgeSet & edges() const
returns the set of edges stored within the EdgeGraphPart
HashTable< std::pair< NodeId, NodeId >, char > __initial_marks
Initial marks for the orientation phase, used to convey constraints.
std::vector< Arc > __latent_couples
an empty vector of arcs
bool existsEdge(const Edge &edge) const
indicates whether a given edge exists
The base class for all directed edges. This class is used as a basis for manipulating all directed edge...
std::vector< std::tuple< std::tuple< NodeId, NodeId, NodeId > *, double, double, double > > _getUnshieldedTriplesMIIC(const MixedGraph &graph, CorrectedMutualInformation<> &I, const HashTable< std::pair< NodeId, NodeId >, std::vector< NodeId > > &sep_set, HashTable< std::pair< NodeId, NodeId >, char > &marks)
gets the list of unshielded triples in the graph in decreasing value of |I'(x, y, z|{ui})|...
const NodeSet & parents(const NodeId id) const
returns the set of nodes with arc ingoing to a given node
Val & pushBack(const Val &val)
Inserts a new element (a copy) at the end of the chained list.
Size _current_step
The current step.
const std::vector< NodeId > directedPath(const NodeId node1, const NodeId node2) const
returns a directed path from node1 to node2 belonging to the set of arcs
virtual void addArc(const NodeId tail, const NodeId head)
insert a new arc into the directed graph
const Sequence< NodeId > & topologicalOrder(bool clear=true) const
The topological order stays the same as long as no variable or arcs are added or erased from the topol...
void _initiation(CorrectedMutualInformation<> &I, MixedGraph &graph, HashTable< std::pair< NodeId, NodeId >, std::vector< NodeId > > &sep_set, Heap< std::pair< std::tuple< NodeId, NodeId, NodeId, std::vector< NodeId > > *, double >, GreaterPairOn2nd > &_rank)
Initiation phase.
std::vector< std::pair< std::tuple< NodeId, NodeId, NodeId > *, double > > _getUnshieldedTriples(const MixedGraph &graph, CorrectedMutualInformation<> &I, const HashTable< std::pair< NodeId, NodeId >, std::vector< NodeId > > &sep_set)
gets the list of unshielded triples in the graph in decreasing value of |I'(x, y, z|{ui})| ...
Val & front() const
Returns a reference to first element of a list, if any.
The base class for all undirected edges.
Miic()
default constructor
std::vector< std::tuple< std::tuple< NodeId, NodeId, NodeId > *, double, double, double > > _updateProbaTriples(const MixedGraph &graph, std::vector< std::tuple< std::tuple< NodeId, NodeId, NodeId > *, double, double, double > > proba_triples)
Gets the orientation probabilities like MIIC for the orientation phase.
Copyright 2005-2019 Pierre-Henri WUILLEMIN et Christophe GONZALES (LIP6) {prenom.nom}_at_lip6.fr.
bool operator()(const std::tuple< std::tuple< NodeId, NodeId, NodeId > *, double, double, double > &e1, const std::tuple< std::tuple< NodeId, NodeId, NodeId > *, double, double, double > &e2) const
int __maxLog
Fixes the maximum log that we accept in exponential computations.
The miic learning algorithm.
const bool __existsDirectedPath(const MixedGraph &graph, const NodeId n1, const NodeId n2) const
checks for directed paths in a graph, considering double arcs as edges
A class that, given a structure and a parameter estimator returns a full Bayes net.
bool operator()(const std::pair< std::tuple< NodeId, NodeId, NodeId, std::vector< NodeId > > *, double > &e1, const std::pair< std::tuple< NodeId, NodeId, NodeId, std::vector< NodeId > > *, double > &e2) const
Heap data structure. This structure is a basic heap data structure, i.e., it is a container in which el...
std::size_t Size
In aGrUM, hashed values are unsigned long int.
const std::vector< NodeId > __empty_set
an empty conditioning set
virtual void eraseEdge(const Edge &edge)
removes an edge from the EdgeGraphPart
Size size() const noexcept
Returns the number of elements in the set.
value_type & insert(const Key &key, const Val &val)
Adds a new element (actually a copy of this element) into the hash table.
#define GUM_EMIT3(signal, arg1, arg2, arg3)
Size __N
size of the database
BayesNet< GUM_SCALAR > learnBN(GRAPH_CHANGES_SELECTOR &selector, PARAM_ESTIMATOR &estimator, DAG initial_dag=DAG())
learns the structure and the parameters of a BN
bool existsArc(const Arc &arc) const
indicates whether a given arc exists
Copyright 2005-2019 Pierre-Henri WUILLEMIN et Christophe GONZALES (LIP6) {prenom.nom}_at_lip6.fr.
Size NodeId
Type for node ids.
DAG learnStructure(CorrectedMutualInformation<> &I, MixedGraph graph)
learns the structure of a Bayesian network, i.e. a DAG, by first learning an Essential graph and then ...
Base class for mixed graphs.
void _orientation_3off2(CorrectedMutualInformation<> &I, MixedGraph &graph, const HashTable< std::pair< NodeId, NodeId >, std::vector< NodeId > > &sep_set)
Orientation phase from the 3off2 algorithm, returns a CPDAG.