# Copyright (c) 2021. Lucas G. S. Jeub
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import typing as _t
import networkx as nx
import torch
import torch_scatter as ts
import torch_geometric as tg
from .graph import Graph
class TGraph(Graph):
"""Wrapper class for pytorch-geometric edge_index providing fast adjacency look-up."""
@staticmethod
def _convert_input(input):
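"""Convert array-like input to a ``torch.Tensor``; ``None`` is passed through unchanged."""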
if input is None:
return None
else:
return torch.as_tensor(input)
def __init__(self, *args, ensure_sorted=False, **kwargs):
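"""
Build fast adjacency look-up structures (degrees and ``adj_index``) for an edge list.
Args:
*args: positional arguments forwarded to ``Graph``
ensure_sorted: if ``True``, sort edges by ``(source, target)`` before building the index (default: ``False``, i.e. edges are assumed to be sorted already)
**kwargs: keyword arguments forwarded to ``Graph``
Example (a minimal sketch; it assumes the base ``Graph`` constructor accepts the ``edge_index`` and ``num_nodes`` keywords that ``subgraph`` below also uses):
>>> edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])  # undirected path 0-1-2
>>> g = TGraph(edge_index=edge_index, num_nodes=3, ensure_sorted=True)
>>> g.degree
tensor([1, 2, 1])
>>> g.adj_index  # neighbours of node i are edge_index[1, adj_index[i]:adj_index[i+1]]
tensor([0, 1, 3, 4])
"""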
super().__init__(*args, **kwargs)
if self.num_nodes is None:
self.num_nodes = int(torch.max(self.edge_index)+1) #: number of nodes
if ensure_sorted:
index = torch.argsort(self.edge_index[0]*self.num_nodes+self.edge_index[1])
self.edge_index = self.edge_index[:, index]
if self.edge_attr is not None:
self.edge_attr = self.edge_attr[index]
if self.adj_index is None:
self.degree = torch.zeros(self.num_nodes, dtype=torch.long, device=self.device) #: tensor of node degrees
self.degree.index_add_(0, self.edge_index[0],
torch.ones(1, dtype=torch.long, device=self.device).expand(self.num_edges)) # use expand to avoid actually allocating large array
self.adj_index = torch.zeros(self.num_nodes + 1, dtype=torch.long, device=self.device) #: adjacency index such that edges starting at node ``i`` are given by ``edge_index[:, adj_index[i]:adj_index[i+1]]``
self.adj_index[1:] = torch.cumsum(self.degree, 0)
else:
self.degree = self.adj_index[1:] - self.adj_index[:-1]
if self.weighted:
self.weights = self.edge_attr
self.strength = torch.zeros(self.num_nodes, device=self.device, dtype=self.weights.dtype) #: tensor of node strength
self.strength.index_add_(0, self.edge_index[0], self.weights)
else:
# use expand to avoid actually allocating large array
self.weights = torch.ones(1, device=self.device).expand(self.num_edges)
self.strength = self.degree
if self.undir is None:
index = torch.argsort(self.edge_index[1]*self.num_nodes+self.edge_index[0])
self.undir = torch.equal(self.edge_index, self.edge_index[:, index].flip((0,)))
if self.weighted:
self.undir = self.undir and torch.equal(self.weights, self.weights[index])
@property
def device(self):
"""device holding graph data"""
return self.edge_index.device
def edges(self):
"""
iterate over the edges, yielding each edge as a tuple ``(source, target)``
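Example (an illustrative sketch on a small undirected path 0-1-2):
>>> g = TGraph(edge_index=torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]]), num_nodes=3)
>>> list(g.edges())
[(0, 1), (1, 0), (1, 2), (2, 1)]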
"""
return ((self.edge_index[0, e].item(), self.edge_index[1, e].item()) for e in range(self.num_edges))
def edges_weighted(self):
"""
iterate over the edges, yielding each edge as a tuple ``(source, target, weight)``
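Example (an illustrative sketch; assumes the base ``Graph`` treats a one-dimensional ``edge_attr`` as edge weights):
>>> g = TGraph(edge_index=torch.tensor([[0, 1], [1, 0]]), edge_attr=torch.tensor([0.5, 0.5]), num_nodes=2)
>>> list(g.edges_weighted())
[(0, 1, 0.5), (1, 0, 0.5)]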
"""
return ((self.edge_index[0, e].item(), self.edge_index[1, e].item(), self.weights[e].cpu().numpy()
if self.weights.ndim > 1 else self.weights[e].item()) for e in range(self.num_edges))
def is_edge(self, source, target):
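"""
Return ``True`` if the graph contains the edge ``(source, target)``.
Uses ``torch.bucketize`` to binary-search the sorted neighbour list of ``source``, so the
edge index needs to be sorted (see ``ensure_sorted``).
Example (an illustrative sketch on a small undirected path 0-1-2):
>>> g = TGraph(edge_index=torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]]), num_nodes=3, ensure_sorted=True)
>>> g.is_edge(0, 1), g.is_edge(0, 2)
(True, False)
"""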
index = torch.bucketize(target, self.edge_index[1, self.adj_index[source]:self.adj_index[source+1]])
if index < self.degree[source] and self.edge_index[1, self.adj_index[source]+index] == target:
return True
else:
return False
def neighbourhood(self, nodes: torch.Tensor, hops: int = 1):
"""
find the neighbourhood of a set of source nodes
note that the neighbourhood includes the source nodes themselves
Args:
nodes: indices of source nodes
hops: number of hops for neighbourhood
Returns:
tensor of node indices in the neighbourhood (source nodes first, followed by newly discovered nodes)
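Example (an illustrative sketch on a small undirected path 0-1-2):
>>> g = TGraph(edge_index=torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]]), num_nodes=3, ensure_sorted=True)
>>> g.neighbourhood(torch.tensor([0]), hops=1)
tensor([0, 1])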
"""
explore = torch.ones(self.num_nodes, dtype=torch.bool, device=self.device)
explore[nodes] = False
all_nodes = [nodes]
new_nodes = nodes
for _ in range(hops):
if new_nodes.numel() == 0:
break  # frontier is empty, no further hops possible (avoids torch.cat on an empty list)
new_nodes = torch.cat([self.adj(node) for node in new_nodes])
new_nodes = torch.unique(new_nodes[explore[new_nodes]])
explore[new_nodes] = False
all_nodes.append(new_nodes)
return torch.cat(all_nodes)
def subgraph(self, nodes: torch.Tensor, relabel=False, keep_x=True, keep_y=True):
"""
find induced subgraph for a set of nodes
Args:
nodes: node indices
relabel: if ``True``, relabel the subgraph nodes with consecutive integers instead of keeping the original node labels
keep_x: if ``True``, keep the node features ``x`` (if present)
keep_y: if ``True``, keep the node labels ``y`` (if present)
Returns:
induced subgraph
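Example (an illustrative sketch on a small undirected path 0-1-2):
>>> g = TGraph(edge_index=torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]]), num_nodes=3, ensure_sorted=True)
>>> sg = g.subgraph(torch.tensor([0, 1]), relabel=True)
>>> list(sg.edges())
[(0, 1), (1, 0)]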
"""
index = torch.cat([torch.arange(self.adj_index[node], self.adj_index[node + 1], dtype=torch.long) for node in nodes])
node_mask = torch.zeros(self.num_nodes, dtype=torch.bool, device=self.device)
node_mask[nodes] = True
node_ids = torch.zeros(self.num_nodes, dtype=torch.long, device=self.device)
node_ids[nodes] = torch.arange(len(nodes), device=self.device)
index = index[node_mask[self.edge_index[1][index]]]
edge_attr = self.edge_attr
if relabel:
node_labels = None
else:
node_labels = [self.nodes[n] for n in nodes]
if self.x is not None and keep_x:
x = self.x[nodes, :]
else:
x = None
if self.y is not None and keep_y:
y = self.y[nodes]
else:
y = None
return self.__class__(edge_index=node_ids[self.edge_index[:, index]],
edge_attr=edge_attr[index] if edge_attr is not None else None,
num_nodes=len(nodes),
ensure_sorted=True,
undir=self.undir,
x=x,
y=y,
nodes=node_labels
)
def connected_component_ids(self):
"""Find the (weakly)-connected components. Component ids are sorted by size, such that id=0 corresponds
to the largest connected component"""
edge_index = self.edge_index
is_undir = self.undir
last_components = torch.full((self.num_nodes,), self.num_nodes, dtype=torch.long, device=self.device)
components = torch.arange(self.num_nodes, dtype=torch.long, device=self.device)
while not torch.equal(last_components, components):
last_components[:] = components
components = ts.scatter(last_components[edge_index[0]], edge_index[1], out=components, reduce='min')
if not is_undir:
components = ts.scatter(last_components[edge_index[1]], edge_index[0], out=components, reduce='min')
component_id, inverse, component_size = torch.unique(components, return_counts=True, return_inverse=True)
# argsort gives the permutation that sorts the sizes; applying argsort again yields the rank
# of each component, so that the largest component is mapped to id 0
new_id = torch.argsort(torch.argsort(component_size, descending=True))
return new_id[inverse]
def nodes_in_lcc(self):
"""List all nodes in the largest connected component"""
return torch.nonzero(self.connected_component_ids() == 0).flatten()
def to_networkx(self):
"""convert graph to NetworkX format"""
if self.undir:
nxgraph = nx.Graph()
else:
nxgraph = nx.DiGraph()
nxgraph.add_nodes_from(range(self.num_nodes))
if self.x is not None:
for i in range(self.num_nodes):
nxgraph.nodes[i]['x'] = self.x[i, :]
if self.y is not None:
for i in range(self.num_nodes):
nxgraph.nodes[i]['y'] = self.y[i]
if self.weighted:
nxgraph.add_weighted_edges_from(self.edges_weighted())
else:
nxgraph.add_edges_from(self.edges())
return nxgraph
def to(self, *args, graph_cls=None, **kwargs):
"""
Convert to a different graph type or move the graph data to a different device.
Args:
graph_cls: target ``Graph`` subclass to convert to
device: device to move the graph data to (forwarded to ``torch.Tensor.to``)
Only one kind of conversion can be specified per call. If a single positional argument is a
``Graph`` subclass, the graph is converted to that class; otherwise all positional and keyword
arguments are forwarded to ``torch.Tensor.to`` for every tensor attribute.
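Example (an illustrative sketch; moving every tensor attribute to the CPU):
>>> g = TGraph(edge_index=torch.tensor([[0, 1], [1, 0]]), num_nodes=2)
>>> g = g.to(torch.device('cpu'))
>>> g.device
device(type='cpu')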
"""
if args:
if graph_cls is not None:
raise ValueError("Both positional and graph_cls keyword argument specified.")
elif len(args) == 1:
arg = args[0]
if isinstance(arg, type) and issubclass(arg, Graph):
graph_cls = arg
if kwargs:
raise ValueError("Cannot specify additional keyword arguments when converting between graph classes.")
if graph_cls is not None:
return super().to(graph_cls)
else:
for key, value in self.__dict__.items():
if isinstance(value, torch.Tensor):
self.__dict__[key] = value.to(*args, **kwargs)
return self
def bfs_order(self, start=0):
"""
return nodes in breadth-first-search order
Args:
start: index of starting node (default: 0)
Returns:
tensor of node indices
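Example (an illustrative sketch on a small undirected path 0-1-2):
>>> g = TGraph(edge_index=torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]]), num_nodes=3, ensure_sorted=True)
>>> g.bfs_order(start=0)
tensor([0, 1, 2])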
"""
bfs_list = torch.full((self.num_nodes,), -1, dtype=torch.long, device=self.device)
not_visited = torch.ones(self.num_nodes, dtype=torch.bool, device=self.device)
bfs_list[0] = start
not_visited[start] = False
append_pointer = 1
i = 0
while append_pointer < self.num_nodes:
node = bfs_list[i]
if node < 0:
node = torch.nonzero(not_visited)[0]
bfs_list[i] = node
not_visited[node] = False
append_pointer += 1
i += 1
new_nodes = self.adj(node)
new_nodes = new_nodes[not_visited[new_nodes]]
number_new_nodes = len(new_nodes)
not_visited[new_nodes] = False
bfs_list[append_pointer:append_pointer+number_new_nodes] = new_nodes
append_pointer += number_new_nodes
return bfs_list
def partition_graph(self, partition, self_loops=True):
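"""
Contract the graph according to a partition of its nodes.
Args:
partition: tensor assigning a cluster id to each node
self_loops: if ``False``, drop edges between nodes of the same cluster
Returns:
coarse-grained graph with one node per cluster, where the edge weights count the number of original edges between the corresponding clusters
Example (an illustrative sketch; clusters {0, 1} and {2} of a small undirected path 0-1-2):
>>> g = TGraph(edge_index=torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]]), num_nodes=3, ensure_sorted=True)
>>> pg = g.partition_graph(torch.tensor([0, 0, 1]), self_loops=False)
>>> list(pg.edges())
[(0, 1), (1, 0)]
"""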
num_clusters = int(torch.max(partition)) + 1
pe_index = partition[self.edge_index[0]]*num_clusters + partition[self.edge_index[1]]
partition_edges, weights = torch.unique(pe_index, return_counts=True)
partition_edges = torch.stack((partition_edges // num_clusters, partition_edges % num_clusters), dim=0)
if not self_loops:
valid = partition_edges[0] != partition_edges[1]
partition_edges = partition_edges[:, valid]
weights = weights[valid]
return self.__class__(edge_index=partition_edges, edge_attr=weights, num_nodes=num_clusters, undir=self.undir)
def sample_negative_edges(self, num_samples):
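"""
Sample ``num_samples`` node pairs that are not edges of the graph, using
``torch_geometric.utils.negative_sampling``. Returns an edge index tensor of shape
``(2, num_samples)`` (possibly fewer pairs if not enough non-edges exist).
"""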
return tg.utils.negative_sampling(self.edge_index, self.num_nodes, num_samples)
def sample_positive_edges(self, num_samples):
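"""
Sample ``num_samples`` edges uniformly at random (with replacement) and return them as a
``(2, num_samples)`` edge index tensor.
"""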
index = torch.randint(self.num_edges, (num_samples,), dtype=torch.long)
return self.edge_index[:, index]