"""Dividing input data into overlapping patches"""
from random import choice
from math import ceil
from collections.abc import Iterable
import torch
import numpy as np
from tqdm.auto import tqdm
import numba
from local2global_embedding.clustering import Partition
from local2global_embedding.network import TGraph, NPGraph
from local2global_embedding.network.npgraph import JitGraph
from local2global_embedding.sparsify import resistance_sparsify, relaxed_spanning_tree, edge_sampling_sparsify, \
    hierarchical_sparsify, nearest_neighbor_sparsify, conductance_weighted_graph
[docs]
@numba.njit
def geodesic_expand_overlap(subgraph, seed_mask, min_overlap, target_overlap, reseed_samples=10):
    """
    Select overlap nodes by growing outwards from a set of seed nodes, one neighbourhood layer at a time.

    Starting from the seed nodes, the not-yet-selected neighbours of the nodes added in the previous step are
    added, until at least ``min_overlap`` nodes are selected. If a step would take the selection beyond
    ``target_overlap`` nodes, the nodes added in that step are sampled at random (without replacement) so that
    the selection ends up with exactly ``target_overlap`` nodes.

    Args:
        subgraph: graph containing patch nodes and all target nodes for potential expansion
        seed_mask: boolean mask over the nodes of ``subgraph`` marking the seed nodes (initial starting nodes
                   for expansion)
        min_overlap: minimum overlap before stopping expansion
        target_overlap: maximum overlap (if expansion step results in more overlap, the nodes
                        added are sampled at random)
        reseed_samples: when no unselected neighbours are left, at most this many of the remaining unselected
                        nodes are sampled at random and added instead

    Returns:
        index array (local node ids of ``subgraph``) of the selected overlap nodes

    Raises:
        RuntimeError: if ``subgraph`` has fewer than ``min_overlap`` nodes, or if no more nodes can be added
                      before ``min_overlap`` is reached
    """
    if subgraph.num_nodes < min_overlap:
        raise RuntimeError("Minimum overlap > number of nodes")
    mask = ~seed_mask  # True for nodes that can still be added
    new_nodes = np.flatnonzero(seed_mask)  # current expansion frontier (local node ids)
    overlap = new_nodes
    if overlap.size > target_overlap:
        # too many seeds: keep a random subset; all seeds stay masked out and still form the first frontier
        overlap = np.random.choice(overlap, target_overlap, replace=False)
    while overlap.size < min_overlap:
        # next layer: unselected neighbours of the current frontier
        # NOTE(review): assumes subgraph.neighbours returns each neighbour once; duplicates would be added twice
        new_nodes = subgraph.neighbours(new_nodes)
        new_nodes = new_nodes[mask[new_nodes]]
        if not new_nodes.size:
            # no more connected nodes to add so add some remaining nodes by random sampling
            new_nodes = np.flatnonzero(mask)
            if new_nodes.size > reseed_samples:
                new_nodes = np.random.choice(new_nodes, reseed_samples, replace=False)
        if overlap.size + new_nodes.size > target_overlap:
            # cap the selection at target_overlap by sampling only part of this layer
            # NOTE(review): the sample size is zero or negative when target_overlap <= overlap.size,
            # which can happen if target_overlap < min_overlap
            new_nodes = np.random.choice(new_nodes, target_overlap - overlap.size, replace=False)
        if not new_nodes.size:
            raise RuntimeError("Could not reach minimum overlap.")
        mask[new_nodes] = False
        overlap = np.concatenate((overlap, new_nodes))
    return overlap 
[docs]
def merge_small_clusters(graph: TGraph, partition_tensor: torch.LongTensor, min_size):
    """
    Merge small clusters with adjacent clusters such that all clusters satisfy a minimum size constraint.

    This function iteratively merges the smallest cluster with the neighbouring cluster that receives the
    largest number of its edges, normalised by that cluster's total degree (which encourages merging with
    smaller clusters). A cluster without any edges to other clusters is merged with the smallest other
    cluster. Merging stops once every cluster has at least ``min_size`` nodes or only one cluster is left.

    Args:
        graph: Input graph
        partition_tensor: input partition vector mapping nodes to clusters. Note that this tensor is
            relabelled in place.
        min_size: desired minimum size of clusters

    Returns:
        new partition tensor where small clusters are merged. When a merged-away cluster id is not the last
        one, the last cluster is relabelled to take its place so that ids do not leave gaps.
    """
    parts = [torch.as_tensor(p, device=graph.device) for p in Partition(partition_tensor)]
    num_parts = len(parts)
    if num_parts == 0:
        return partition_tensor
    part_degs = torch.tensor([graph.degree[p].sum() for p in parts], device=graph.device)
    sizes = torch.tensor([len(p) for p in parts], dtype=torch.long)
    smallest_id = torch.argmin(sizes)
    # with a single cluster left there is nothing to merge with
    while num_parts > 1 and sizes[smallest_id] < min_size:
        # count edges from the smallest cluster into every cluster (including itself)
        edge_counts = torch.zeros(num_parts, device=graph.device)
        for node in parts[smallest_id]:
            neighbour_parts = partition_tensor[graph.adj(node)]
            edge_counts.scatter_add_(0, neighbour_parts,
                                     torch.ones(1, device=graph.device).expand(neighbour_parts.shape))
        # internal edges must never make the cluster a merge candidate for itself
        edge_counts[smallest_id] = 0
        if edge_counts.sum() == 0:
            # no edges to other clusters: merge with the smallest *other* cluster (robust to size ties)
            other_sizes = sizes.clone()
            other_sizes[smallest_id] = torch.iinfo(other_sizes.dtype).max
            merge = torch.argmin(other_sizes)
        else:
            # normalise by cluster degree to encourage merging with smaller clusters; zero-degree clusters
            # (whose count is necessarily 0) are divided by 1 to avoid 0/0 = nan, which argmax would select
            safe_degs = torch.where(part_degs > 0, part_degs, torch.ones_like(part_degs))
            edge_counts /= safe_degs
            merge = torch.argmax(edge_counts)
        # keep the lower id and remove the higher one
        if merge > smallest_id:
            new_id, old_id = smallest_id, merge
        else:
            new_id, old_id = merge, smallest_id
        partition_tensor[parts[old_id]] = new_id
        sizes[new_id] += sizes[old_id]
        part_degs[new_id] += part_degs[old_id]
        parts[new_id] = torch.cat((parts[new_id], parts[old_id]))
        if old_id < num_parts - 1:
            # move the last cluster into the freed id so that ids stay gap-free
            partition_tensor[parts[-1]] = old_id
            sizes[old_id] = sizes[-1]
            part_degs[old_id] = part_degs[-1]
            parts[old_id] = parts[-1]
        num_parts -= 1
        sizes = sizes[:num_parts]
        part_degs = part_degs[:num_parts]
        parts = parts[:num_parts]
        smallest_id = torch.argmin(sizes)
    return partition_tensor
[docs]
def create_overlapping_patches(graph, partition_tensor: torch.LongTensor, patch_graph, min_overlap,
                               target_overlap):
    """
    Create overlapping patches from a hard partition of an input graph.

    Every cluster starts as its own patch. Then, for each cluster ``i`` and each neighbour ``j`` of ``i`` in
    ``patch_graph``, nodes of cluster ``i`` (grown from those with an edge into cluster ``j``) are appended to
    patch ``j``. Because both endpoints of a patch-graph edge contribute, each side is asked for half of the
    requested overlap.

    Args:
        graph: input graph
        partition_tensor: partition of input graph
        patch_graph: graph where nodes are clusters of partition and an edge indicates that the corresponding
                     patches in the output should have at least ``min_overlap`` nodes in common
        min_overlap: minimum overlap for connected patches
        target_overlap: maximum overlap during expansion for an edge (additional overlap may
                        result from expansion of other edges)

    Returns:
        ``numba.typed.List`` of numpy node-index arrays, one per patch
    """
    if isinstance(partition_tensor, torch.Tensor):
        partition_tensor = partition_tensor.cpu()
    jit_graph = graph.to(NPGraph)._jitgraph
    jit_patch_graph = patch_graph.to(NPGraph)._jitgraph
    clusters = Partition(partition_tensor)
    partition_array = partition_tensor.numpy()
    patches = numba.typed.List(np.asanyarray(cluster) for cluster in clusters)
    # each endpoint of a patch-graph edge contributes half of the overlap
    half_min_overlap = ceil(min_overlap / 2)
    half_target_overlap = int(target_overlap / 2)
    for patch_id in tqdm(range(jit_patch_graph.num_nodes), desc='enlarging patch overlaps'):
        cluster_nodes = clusters[patch_id].numpy()
        cluster_nodes.sort()
        patches = _patch_overlaps(patch_id, cluster_nodes, partition_array, patches, jit_graph,
                                  jit_patch_graph, half_min_overlap, half_target_overlap)
    return patches
@numba.njit
def _patch_overlaps(i, part, partition, patches, graph, patch_graph, min_overlap, target_overlap):
    """
    Extend every patch adjacent to cluster ``i`` with nodes taken from cluster ``i``.

    Builds the subgraph induced by the nodes of cluster ``i``. For each neighbour ``j`` of ``i`` in
    ``patch_graph``, the nodes of cluster ``i`` that have an edge into cluster ``j`` are used as seeds for
    ``geodesic_expand_overlap`` on that subgraph, and the selected nodes are appended to ``patches[j]``.

    Args:
        i: index of the current cluster / patch
        part: node indices (in ``graph``) of cluster ``i``
        partition: array mapping each node of ``graph`` to its cluster id
        patches: list of node-index arrays, one per patch; entries for the neighbours of ``i`` are replaced
        graph: full input graph (jit graph)
        patch_graph: graph over clusters whose edges request overlap (jit graph)
        min_overlap: minimum number of nodes from cluster ``i`` to add to each neighbouring patch
        target_overlap: maximum number of nodes from cluster ``i`` to add to each neighbouring patch

    Returns:
        the updated ``patches`` list
    """
    # the degree sum over the cluster bounds the number of edges inside the induced subgraph
    max_edges = graph.degree[part].sum()
    edge_index = np.empty((2, max_edges), dtype=np.int64)
    adj_index = np.zeros((len(part)+1,), dtype=np.int64)  # CSR-style row offsets into edge_index
    part_index = np.full((graph.num_nodes,), -1, dtype=np.int64)  # global node id -> local index (-1 if outside)
    part_index[part] = np.arange(len(part))
    patch_index = np.full((patch_graph.num_nodes,), -1, dtype=np.int64)  # patch id -> column of source_mask
    patch_index[patch_graph.adj(i)] = np.arange(patch_graph.degree[i])
    source_mask = np.zeros((part.size, patch_graph.degree[i]), dtype=np.bool_)  # track source nodes for different patches
    edge_count = 0
    for index in range(len(part)):
        targets = graph.adj(part[index])
        # keep only edges whose target lies inside the cluster, relabelled to local indices
        for t in part_index[targets]:
            if t >= 0:
                edge_index[0, edge_count] = index
                edge_index[1, edge_count] = t
                edge_count += 1
        adj_index[index+1] = edge_count
        # mark this node as a seed for every neighbouring patch that it has an edge into
        pi = patch_index[partition[targets]]
        pi = pi[pi >= 0]
        source_mask[index][pi] = True
    edge_index = edge_index[:, :edge_count]
    # NOTE(review): presumably JitGraph(edge_index, num_nodes, adj_index, edge_weights) - confirm signature
    subgraph = JitGraph(edge_index, len(part), adj_index, None)
    # `it` enumerates neighbours in the same order used to fill patch_index, i.e. the source_mask column
    for it, j in enumerate(patch_graph.adj(i)):
        patches[j] = np.concatenate((patches[j],
                                     part[geodesic_expand_overlap(
                                         subgraph,
                                         seed_mask=source_mask[:, it],
                                         min_overlap=min_overlap,
                                         target_overlap=target_overlap)]))
    return patches
[docs]
def create_patch_data(graph: TGraph, partition_tensor, min_overlap, target_overlap,
                      min_patch_size=None, sparsify_method='resistance', target_patch_degree=4, gamma=0, verbose=False):
    """
    Divide data into overlapping patches

    Args:
        graph: input data
        partition_tensor: starting partition for creating patches. Either a single partition tensor (small
            clusters are then merged first) or a list of partition tensors, where the first entry is used to
            create the patches and the remaining entries are passed to ``hierarchical_sparsify``.
        min_overlap: minimum patch overlap for connected patches
        target_overlap: maximum patch overlap during expansion of an edge of the patch graph
        min_patch_size: minimum size of patches (defaults to ``min_overlap``)
        sparsify_method: method for sparsifying patch graph (one of ``'resistance'``, ``'rmst'``, ``'sample'``,
            ``'neighbors'``, ``'none'``)
        target_patch_degree: target patch degree for ``sparsify_method`` in ``'resistance'``, ``'sample'``
            and ``'neighbors'``
        gamma: ``gamma`` value for use with ``sparsify_method='rmst'``
        verbose: if true, print some info about created patches

    Returns:
        list of patch data, patch graph

    Raises:
        ValueError: if ``sparsify_method`` is not one of the supported methods
    """
    # validate first: an invalid method should fail before any work is done (merge_small_clusters
    # relabels the partition in place)
    if sparsify_method not in ('resistance', 'rmst', 'sample', 'neighbors', 'none'):
        raise ValueError(
            f"Unknown sparsify method '{sparsify_method}', should be one of 'resistance', 'rmst', 'sample', "
            f"'neighbors', or 'none'.")
    if min_patch_size is None:
        min_patch_size = min_overlap
    hierarchical = isinstance(partition_tensor, list)
    if hierarchical:
        partition_tensor_0 = partition_tensor[0]
    else:
        partition_tensor = merge_small_clusters(graph, partition_tensor, min_patch_size)
        partition_tensor_0 = partition_tensor
    if verbose:
        print(f"number of patches: {partition_tensor_0.max().item() + 1}")
    pg = graph.partition_graph(partition_tensor_0, self_loops=False).to(TGraph)
    pg = _connect_patch_graph_components(pg)
    if sparsify_method == 'resistance':
        if hierarchical:
            pg = hierarchical_sparsify(pg, partition_tensor[1:], target_patch_degree, sparsifier=resistance_sparsify)
        else:
            pg = resistance_sparsify(pg, target_mean_degree=target_patch_degree)
    elif sparsify_method == 'rmst':
        pg = relaxed_spanning_tree(pg, maximise=True, gamma=gamma)
    elif sparsify_method == 'sample':
        if hierarchical:
            pg = hierarchical_sparsify(pg, partition_tensor[1:], target_patch_degree,
                                       sparsifier=edge_sampling_sparsify)
        else:
            pg = edge_sampling_sparsify(pg, target_patch_degree)
    elif sparsify_method == 'neighbors':
        if hierarchical:
            pg = hierarchical_sparsify(pg, partition_tensor[1:], target_patch_degree,
                                       sparsifier=nearest_neighbor_sparsify)
        else:
            pg = nearest_neighbor_sparsify(pg, target_patch_degree)
    # sparsify_method == 'none': keep the patch graph as is
    if verbose:
        print(f"average patch degree: {pg.num_edges / pg.num_nodes}")
    patches = create_overlapping_patches(graph, partition_tensor_0, pg, min_overlap, target_overlap)
    return patches, pg


def _connect_patch_graph_components(pg):
    """
    Make the patch graph connected.

    If ``pg`` has more than one connected component, add one undirected edge (with weight 1) between randomly
    chosen patches for every pair of components, and reweight the result with ``conductance_weighted_graph``.
    Otherwise ``pg`` is returned unchanged.
    """
    components = pg.connected_component_ids()
    num_components = int(components.max()) + 1
    if num_components <= 1:
        return pg
    comp_lists = [[] for _ in range(num_components)]
    for node, comp in enumerate(components.tolist()):
        comp_lists[comp].append(node)
    # one random representative pair per pair of components, in (c1 < c2) order
    new_edges = torch.tensor(
        [(choice(comp_lists[c1]), choice(comp_lists[c2]))
         for c1 in range(num_components) for c2 in range(c1 + 1, num_components)],
        dtype=torch.long, device=pg.edge_index.device).T
    # add both directions; torch has no negative-step slicing, so use flip
    edge_index = torch.cat((pg.edge_index, new_edges, new_edges.flip(0)), dim=1)
    weights = torch.cat((pg.edge_attr,
                         torch.ones(2 * new_edges.shape[1], dtype=pg.edge_attr.dtype,
                                    device=pg.edge_attr.device)))
    pg = TGraph(edge_index=edge_index, edge_attr=weights, ensure_sorted=True, num_nodes=pg.num_nodes,
                undir=pg.undir)
    return conductance_weighted_graph(pg)
[docs]
def rolling_window_graph(n_patches, w):
    """
    Generate patch edges for a rolling window

    Args:
        n_patches: Number of patches
        w: window specification. If an integer, patch ``i`` is connected to patches ``i - k`` and ``i + k``
           for ``k`` in ``range(1, w)``, i.e. to the ``w - 1`` nearest neighbours on either side. If an
           iterable, it gives the offsets ``k`` explicitly; offsets that would give a self-loop or an
           out-of-range patch are skipped.

    Returns:
        undirected ``TGraph`` with ``n_patches`` nodes and the rolling-window edges
    """
    # materialise once: the offsets are iterated twice per patch, which would silently exhaust a generator
    offsets = list(w) if isinstance(w, Iterable) else list(range(1, w))
    edges = []
    for i in range(n_patches):
        for k in offsets:
            j = i - k
            if 0 <= j < n_patches and j != i:
                edges.append((i, j))
        for k in offsets:
            j = i + k
            if 0 <= j < n_patches and j != i:
                edges.append((i, j))
    # reshape keeps a (2, 0) long tensor when there are no edges (torch.tensor([]) would be 1-d float)
    edge_index = torch.tensor(edges, dtype=torch.long).reshape(-1, 2).T
    return TGraph(edge_index=edge_index, num_nodes=n_patches, undir=True)