# Source code for otx.api.entities.graph

"""This module implements the Graph and MultiDiGraph entities."""
# Copyright (C) 2021-2022 Intel Corporation
# SPDX-License-Identifier: Apache-2.0
#

from typing import Union

import networkx as nx

from otx.api.entities.interfaces.graph_interface import IGraph


class Graph(IGraph):
    """The concrete implementation of IGraph.

    This implementation is using networkx library.

    Args:
        directed (bool): set to True if the graph is a directed graph. A directed
            graph is stored as ``nx.MultiDiGraph``, an undirected one as ``nx.Graph``.
    """

    def __init__(self, directed: bool = False):
        self._graph: Union[nx.Graph, nx.MultiDiGraph] = nx.MultiDiGraph() if directed else nx.Graph()
        self.directed = directed

    def get_graph(self) -> Union[nx.Graph, nx.MultiDiGraph]:
        """Get the underlying NetworkX graph."""
        return self._graph

    def set_graph(self, graph: Union[nx.Graph, nx.MultiDiGraph]) -> None:
        """Set the underlying NetworkX graph.

        Note: only the wrapped graph is replaced; the `directed` attribute is left untouched.
        """
        self._graph = graph

    def add_edge(self, node1, node2, edge_value=None) -> None:
        """Adds edge between node1 and node2.

        Args:
            node1: first end point of the edge (the source for a directed graph).
            node2: second end point of the edge (the destination for a directed graph).
            edge_value: stored on the edge under the ``value`` attribute.
        """
        # pylint: disable=arguments-differ
        self._graph.add_edge(node1, node2, value=edge_value)

    def num_nodes(self) -> int:
        """Returns the number of nodes in the graph."""
        return self._graph.number_of_nodes()

    def add_node(self, node) -> None:
        """Adds node to the graph, unless it is already part of it."""
        if node not in self._graph.nodes:
            self._graph.add_node(node)

    def has_edge_between(self, node1, node2) -> bool:
        """Returns True if there is an edge between node1 and node2.

        The check is implemented as "`node1` is one of the neighbors of `node2`".
        """
        return node1 in self.neighbors(node2)

    def neighbors(self, node) -> list:
        """Returns neighbors of `node`.

        Note: when `node` does not exist in the graph an empty list is returned
        """
        try:
            return list(self._graph.neighbors(node))
        except nx.NetworkXError:
            return []

    def find_out_edges(self, node):
        """Returns the outgoing edges of `node`, i.e. the edges that have `node` as a source.

        Returns:
            The out-edge view of the underlying ``nx.MultiDiGraph``, or an empty list
            when the underlying graph is not an ``nx.MultiDiGraph``.

        Raises:
            KeyError: if `node` is not part of the graph.
        """
        # pylint: disable=no-member
        if node not in self._graph.nodes:
            raise KeyError(f"The node `{node}` is not part of the graph")

        if isinstance(self._graph, nx.MultiDiGraph):
            return self._graph.out_edges(node)
        return []

    def find_in_edges(self, node):
        """Returns the incoming edges of `node`, i.e. the edges that have `node` as a destination.

        Returns:
            The in-edge view of the underlying ``nx.MultiDiGraph``, or an empty list
            when the underlying graph is not an ``nx.MultiDiGraph``.

        Raises:
            KeyError: if `node` is not part of the graph.
        """
        # pylint: disable=no-member
        if node not in self._graph.nodes:
            raise KeyError(f"The node `{node}` is not part of the graph")

        if isinstance(self._graph, nx.MultiDiGraph):
            return self._graph.in_edges(node)
        return []

    def find_cliques(self):
        """Returns cliques in the graph.

        The call is delegated to ``nx.algorithms.clique.find_cliques`` on the underlying graph.
        NOTE(review): NetworkX documents this algorithm for undirected graphs only - confirm
        callers never use it on a directed graph.
        """
        return nx.algorithms.clique.find_cliques(self._graph)

    @property
    def nodes(self):
        """Returns the nodes in the graph."""
        return self._graph.nodes

    @property
    def edges(self):
        """Returns all the edges in the graph.

        Edge data is always included; for an ``nx.MultiDiGraph`` the edge keys are included too.
        """
        if isinstance(self._graph, nx.MultiDiGraph):
            return self._graph.edges(keys=True, data=True)
        return self._graph.edges(data=True)

    @property
    def num_labels(self) -> int:
        """Returns the number of labels in the graph."""
        # Every node is a label. Counting nodes directly gives the same number as the first
        # dimension of the adjacency matrix, without building an n x n matrix and without
        # relying on `to_numpy_matrix`, which is no longer available in NetworkX 3.
        return self._graph.number_of_nodes()

    def remove_edges(self, node1, node2) -> None:
        """Removes the edge between both the nodes.

        The call is delegated to ``remove_edge`` of the underlying graph without an edge key.
        Per the NetworkX documentation, on a multigraph this removes a single edge per call,
        and an error is raised when there is no edge between the nodes.

        Args:
            node1: first end point of the edge to remove.
            node2: second end point of the edge to remove.
        """
        self._graph.remove_edge(node1, node2)

    def remove_node(self, node) -> None:
        """Remove node from graph.

        Args:
            node: node to remove
        """
        self._graph.remove_node(node)

    def descendants(self, parent) -> list:
        """Returns descendants.

        (children and children of children, etc.) of `parent`.

        The graph is traversed edge by edge against the edge direction, starting at `parent`.
        The first end point of every traversed edge is collected, so a node shows up once
        per traversed edge that starts at it. An empty list is returned when NetworkX
        raises a ``NetworkXError`` during the traversal.
        """
        try:
            traversed = list(nx.edge_dfs(self._graph, parent, orientation="reverse"))
        except nx.NetworkXError:
            traversed = []
        return [edge[0] for edge in traversed]

    def __eq__(self, other: object) -> bool:
        """Returns True if the two graphs are equal.

        Two graphs are equal when `other` is a Graph with the same `directed` flag,
        the same nodes and the same edges.
        """
        if not isinstance(other, Graph):
            return False
        return (
            self.directed == other.directed
            and self._graph.nodes == other._graph.nodes
            and self._graph.edges == other._graph.edges
        )
class MultiDiGraph(Graph):
    """Directed multi-edge implementation of a Graph.

    The instance is always created as a directed `Graph`.
    """

    def __init__(self) -> None:
        super().__init__(directed=True)

    def topological_sort(self):
        """Returns a generator of nodes in topologically sorted order.

        The call is delegated to ``nx.topological_sort`` on the underlying graph.
        """
        return nx.topological_sort(self._graph)