"""Implementation of local2global algorithm"""
import scipy as sp
from scipy import sparse as ss
from scipy.linalg import orthogonal_procrustes
import numpy as np
from scipy.sparse.linalg import lsqr, lsmr
from scipy.spatial import procrustes
import copy
from collections import defaultdict
from typing import List, Callable, Any
import networkx as nx
from pathlib import Path
import json
# Module-level random number generator; `seed` rebinds it, and it supplies the
# random start vectors for the eigensolvers used during synchronisation.
rg = np.random.default_rng()
# Machine epsilon of the built-in float type.
eps = np.finfo(float).eps
def seed(new_seed):
    """
    Re-seed the module-level random number generator.

    Replaces the global generator ``rg`` with a freshly constructed one.

    Args:
        new_seed: seed value handed to ``numpy.random.default_rng``
    """
    global rg
    fresh_generator = np.random.default_rng(new_seed)
    rg = fresh_generator
def ensure_extension(filename, extension):
    """
    Make sure a filename carries the expected extension.

    Args:
        filename: input filename (string or path-like)
        extension: required extension (including the leading `.`)

    Returns:
        ``pathlib.Path`` for the filename, with the extension appended if it
        had none

    Raises:
        ValueError: if the filename already has a different extension
    """
    path = Path(filename)
    suffix = path.suffix
    if suffix == extension:
        # already correct, nothing to do
        return path
    if suffix:
        raise ValueError(f"filename should have extension {extension}, not {suffix}")
    return path.with_suffix(extension)
def procrustes_error(coordinates1, coordinates2):
    """
    Procrustes alignment error between two coordinate sets.

    Args:
        coordinates1: First set of coordinates (array-like)
        coordinates2: Second set of coordinates (array-like)

    Returns:
        the disparity reported by ``scipy.spatial.procrustes``

    Note that the two sets of coordinates need to have the same shape.
    """
    _, _, disparity = procrustes(coordinates1, coordinates2)
    return disparity
def orthogonal_MSE_error(rots1, rots2):
    """
    MSE between two lists of orthogonal transformations, up to a global transformation.

    Args:
        rots1: First list of orthogonal matrices
        rots2: Second list of orthogonal matrices

    Returns:
        ``2 * (dim - sum of singular values)`` of the mean of ``rots1[k].T @ rots2[k]``
    """
    dim = len(rots1[0])
    # transpose every matrix of the first list (swap the two matrix axes)
    first_transposed = np.swapaxes(np.asarray(rots1), 1, 2)
    second = np.asarray(rots2)
    combined = np.matmul(first_transposed, second).mean(axis=0)
    _, singular_values, _ = sp.linalg.svd(combined)
    return 2 * (dim - np.sum(singular_values))
def _cov_svd(coordinates1: np.ndarray, coordinates2: np.ndarray):
"""
Compute SVD of covariance matrix between two sets of coordinates
Args:
coordinates1: First set of coordinates (array-like)
coordinates2: Second set of coordinates (array-like)
Note that the two sets of coordinates need to have the same shape.
"""
coordinates1 = coordinates1 - coordinates1.mean(axis=0)
coordinates2 = coordinates2 - coordinates2.mean(axis=0)
cov = coordinates1.T @ coordinates2
return sp.linalg.svd(cov)
def nearest_orthogonal(mat):
    """
    Nearest orthogonal matrix to the given matrix.

    Args:
        mat: input matrix

    Returns:
        product ``U @ Vh`` of the singular-vector factors of ``mat``
    """
    left, _, right_h = sp.linalg.svd(mat)
    return left @ right_h
def relative_scale(coordinates1, coordinates2, clamp=1e8):
    """
    compute relative scale of two sets of coordinates for the same nodes

    The scale of a coordinate set is the norm of the coordinates after
    subtracting their mean.

    Args:
        coordinates1: First set of coordinates (array-like)
        coordinates2: Second set of coordinates (array-like)
        clamp: largest ratio returned; results are limited to the
            range [``1/clamp``, ``clamp``] (default: 1e8)

    Returns:
        ratio ``scale1 / scale2``, clamped; ``1.0`` if both coordinate sets
        have zero spread (the ratio is undefined in that case)

    Note that the two sets of coordinates need to have the same shape.
    """
    scale1 = np.linalg.norm(coordinates1 - np.mean(coordinates1, axis=0))
    scale2 = np.linalg.norm(coordinates2 - np.mean(coordinates2, axis=0))
    if scale1 == 0 and scale2 == 0:
        # neither clamp test below triggers for 0/0, which would otherwise
        # return NaN; fall back to the neutral scale instead
        return 1.0
    if scale1 > clamp * scale2:
        print('extremely large scale clamped')
        return clamp
    if scale1 * clamp < scale2:
        print('extremely small scale clamped')
        return 1/clamp
    return scale1 / scale2
class Patch:
    """
    Class for patch embedding
    """
    index = None
    """mapping of node index to patch coordinate index"""

    coordinates = None
    """patch embedding coordinates"""

    def __init__(self, nodes, coordinates):
        """
        Initialise a patch from a list of nodes and corresponding coordinates

        Args:
            nodes: Iterable of integer node indices for patch
            coordinates: Array-like of node coordinates of shape (len(nodes), dim)
        """
        self.index = {int(n): i for i, n in enumerate(nodes)}
        coords = np.array(coordinates)
        if not np.issubdtype(coords.dtype, np.inexact):
            # integer (or boolean) input would make in-place scaling or
            # translation by floating-point values fail, so promote it;
            # existing floating-point dtypes are kept as they are
            coords = coords.astype(float)
        self.coordinates = coords

    @property
    def shape(self):
        """
        shape of patch coordinates

        (`shape[0]` is the number of nodes in the patch
        and `shape[1]` is the embedding dimension)
        """
        return self.coordinates.shape

    def get_coordinates(self, nodes):
        """
        get coordinates for a list of nodes

        Args:
            nodes: Iterable of node indices

        Returns:
            array with one row of coordinates per requested node, in the
            order given

        Raises:
            KeyError: if a node is not part of the patch
        """
        return self.coordinates[[self.index[node] for node in nodes], :]

    def get_coordinate(self, node):
        """
        get coordinate for a single node

        Args:
            node: Integer node index

        Returns:
            coordinate row for ``node``

        Raises:
            KeyError: if the node is not part of the patch
        """
        return self.coordinates[self.index[node], :]

    def __copy__(self):
        """return a copy of the patch (index and coordinates are duplicated)"""
        instance = self.__new__(type(self))
        instance.index = dict(self.index)
        instance.coordinates = np.array(self.coordinates)
        return instance
class AlignmentProblem:
    """
    Implements the standard local2global algorithm using an unweighted patch graph
    """
    n_nodes = None
    """total number of nodes"""

    n_patches = None
    """number of patches"""

    dim = None
    """embedding dimension"""

    scales = None
    """tracks scale transformations applied to patches (updated by :meth:`scale_patches`)"""

    rotations = None
    """tracks orthogonal transformations applied to patches (updated by :meth:`rotate_patches`)"""

    shifts = None
    """tracks translation transformations applied to patches (updated by :meth:`scale_patches`,
    :meth:`rotate_patches`, and :meth:`translate_patches`)"""

    verbose = False
    """print debug output if `True`"""

    def weight(self, i, j):
        """Compute the weighting factor for a pair of patches

        Args:
            i: First patch index
            j: Second patch index

        Returns:
            1

        Override this in subclasses for weighted alignment
        """
        return 1

    def __init__(self, patches: List["Patch"], patch_edges=None,
                 min_overlap=None, copy_data=True, self_loops=False, verbose=False):
        """
        Initialise the alignment problem with a list of patches

        Args:
            patches: List of patches to synchronise
            patch_edges: if provided, only compute relative transformations for given patch edges (all pairs of patches
                         with at least ``min_overlap`` points in common are included by default)
            min_overlap (int): minimum number of points in the overlap required for two patches to be considered
                               connected (defaults to `dim+1`) where `dim` is the embedding dimension of the patches
            copy_data (bool): if ``True``, input patches are copied (default: ``True``)
            self_loops (bool): if ``True``, self-loops from a patch to itself are included in the synchronisation problem
                               (default: ``False``)
            verbose(bool): if True print diagnostic information (default: ``False``)

        Raises:
            RuntimeError: if an edge in ``patch_edges`` has fewer than ``min_overlap`` shared nodes, or if
                          the patch graph (including patches without any edge) is not connected
        """
        if copy_data:
            self.patches = [copy.copy(patch) for patch in patches]
        else:
            self.patches = patches
        self.verbose = verbose
        self.n_nodes = max(max(patch.index.keys()) for patch in self.patches) + 1
        self.n_patches = len(self.patches)
        self.dim = self.patches[0].shape[1]
        self.scales = np.ones(self.n_patches)
        self.rotations = np.tile(np.eye(self.dim), (self.n_patches, 1, 1))
        self.shifts = np.zeros((self.n_patches, self.dim))
        self._aligned_embedding = None
        if min_overlap is None:
            min_overlap = self.dim + 1

        # create an index for the patch membership of each node
        self.patch_index = [[] for _ in range(self.n_nodes)]
        for i, patch in enumerate(self.patches):
            for node in patch.index.keys():
                self.patch_index[node].append(i)

        # find patch overlaps
        self.patch_overlap = defaultdict(list)
        for i, patch in enumerate(self.patches):
            for node in patch.index:
                for j in self.patch_index[node]:
                    if self_loops or i != j:
                        self.patch_overlap[i, j].append(node)

        # restrict to patch edges if provided
        if patch_edges is not None:
            self.patch_overlap = {e: self.patch_overlap[e] for e in patch_edges}

        # remove small overlaps
        keys = list(self.patch_overlap.keys())
        for e in keys:
            if len(self.patch_overlap[e]) < min_overlap:
                if patch_edges is None:
                    del self.patch_overlap[e]
                else:
                    raise RuntimeError("Patch edges do not satisfy minimum overlap")

        # find patch degrees
        self.patch_degrees = [0] * self.n_patches
        for i, j in self.patch_overlap.keys():
            self.patch_degrees[i] += 1

        patch_graph = nx.Graph()
        # add every patch as a node first: a patch without any edge would
        # otherwise be missing from the graph and go undetected by the
        # connectivity check
        patch_graph.add_nodes_from(range(self.n_patches))
        patch_graph.add_edges_from(self.patch_overlap.keys())
        if nx.number_connected_components(patch_graph) > 1:
            raise RuntimeError("patch graph is not connected")

        if self.verbose:
            print(f'mean patch degree: {np.mean(self.patch_degrees)}')

    def scale_patches(self, scale_factors=None):
        """
        Synchronise scales of the embeddings for each patch

        Args:
            scale_factors: if provided apply the given scales instead of synchronising
        """
        if scale_factors is None:
            scale_factors = [1 / x for x in self.calc_synchronised_scales()]

        for i, scale in enumerate(scale_factors):
            self.patches[i].coordinates *= scale
            # track transformations
            self.scales[i] *= scale
            self.shifts[i] *= scale

    def calc_synchronised_scales(self, max_scale=1e8):
        """
        Compute the scaling transformations that best align the patches

        Args:
            max_scale: maximum allowed scale (all scales are clipped to the range [``1/max_scale``, ``max_scale``])
                       (default: 1e8)

        Returns:
            list of scales
        """
        scaling_mat = self._transform_matrix(lambda ov1, ov2: relative_scale(ov1, ov2, max_scale), 1)
        vec = self._synchronise(scaling_mat, 1)
        vec = vec.flatten()
        vec = np.abs(vec)
        vec /= vec.mean()
        vec = np.clip(vec, a_min=1/max_scale, a_max=max_scale, out=vec)  # avoid blow-up
        return vec

    def rotate_patches(self, rotations=None):
        """align the rotation/reflection of all patches

        Args:
            rotations: If provided, apply the given transformations instead of synchronizing patch rotations
        """
        if rotations is None:
            rotations = (rot.T for rot in self.calc_synchronised_rotations())

        for i, rot in enumerate(rotations):
            self.patches[i].coordinates = self.patches[i].coordinates @ rot.T
            # track transformations
            self.rotations[i] = self.rotations[i] @ rot.T
            self.shifts[i] = self.shifts[i] @ rot.T

    def calc_synchronised_rotations(self):
        """Compute the orthogonal transformations that best align the patches

        Returns:
            array with one ``dim x dim`` orthogonal matrix per patch
        """
        rots = self._transform_matrix(relative_orthogonal_transform, self.dim, symmetric_weights=True)
        vecs = self._synchronise(rots, blocksize=self.dim, symmetric=True)
        # project each eigenvector block onto the orthogonal matrices, in place
        for mat in vecs:
            mat[:] = nearest_orthogonal(mat)
        return vecs

    def translate_patches(self, translations=None):
        """align the patches by translation

        Args:
            translations: If provided, apply the given translations instead of synchronizing
        """
        if translations is None:
            translations = self.calc_synchronised_translations()

        for i, t in enumerate(translations):
            self.patches[i].coordinates += t
            # keep track of transformations
            self.shifts[i] += t

    def calc_synchronised_translations(self):
        """Compute translations that best align the patches

        Solves one least-squares problem per embedding dimension: for every
        patch edge ``(p1, p2)`` the difference of the two translations should
        match the mean coordinate difference over the overlap.

        Returns:
            ``n_patches x dim`` array of translations
        """
        b = np.empty((len(self.patch_overlap), self.dim))
        row = []
        col = []
        val = []
        for i, ((p1, p2), overlap) in enumerate(self.patch_overlap.items()):
            row.append(i)
            col.append(p1)
            val.append(-1)
            row.append(i)
            col.append(p2)
            val.append(1)
            b[i, :] = np.mean(self.patches[p1].get_coordinates(overlap)
                              - self.patches[p2].get_coordinates(overlap), axis=0)
        A = ss.coo_matrix((val, (row, col)), shape=(len(self.patch_overlap), self.n_patches), dtype=np.int8)
        A = A.tocsr()
        translations = np.empty((self.n_patches, self.dim))
        for d in range(self.dim):
            translations[:, d] = lsmr(A, b[:, d], atol=1e-16, btol=1e-16)[0]
            # TODO: probably doesn't need to be that accurate, this is for testing
        return translations

    def mean_embedding(self):
        """
        Compute node embeddings as the centroid over patch embeddings

        Returns:
            ``n_nodes x dim`` array of coordinates
        """
        embedding = np.empty((self.n_nodes, self.dim))
        for node, patch_list in enumerate(self.patch_index):
            embedding[node] = np.mean([self.patches[p].get_coordinate(node) for p in patch_list], axis=0)
        return embedding

    def get_aligned_embedding(self, scale=False, realign=False):
        """Return the aligned embedding

        Args:
            scale (bool): if ``True``, rescale patches (default: ``False``)
            realign (bool): if ``True``, recompute aligned embedding even if it already exists (default: ``False``)

        Returns:
            n_nodes x dim numpy array of embedding coordinates
        """
        if realign or self._aligned_embedding is None:
            if scale:
                self.scale_patches()
            self.rotate_patches()
            self.translate_patches()
            self._aligned_embedding = self.mean_embedding()
        return self._aligned_embedding

    def save_patches(self, filename):
        """
        save patch embeddings to json file

        Args:
            filename: path to output file
        """
        filename = ensure_extension(filename, '.json')
        patch_dict = {str(i): {int(node): [float(c) for c in coord]
                               for node, coord in zip(patch.index, patch.coordinates)}
                      for i, patch in enumerate(self.patches)}
        with open(filename, 'w') as f:
            json.dump(patch_dict, f)

    @classmethod
    def load(cls, filename):
        """
        restore ``AlignmentProblem`` from patch file

        Args:
            filename: path to patch file

        Returns:
            new instance built from the stored patches with default options
        """
        filename = ensure_extension(filename, '.json')
        with open(filename) as f:
            patch_dict = json.load(f)
        patch_list = [None] * len(patch_dict)
        for i, patch_data in patch_dict.items():
            nodes = (int(n) for n in patch_data.keys())
            coordinates = list(patch_data.values())
            patch_list[int(i)] = Patch(nodes, coordinates)
        return cls(patch_list)

    def _embedding_as_dict(self):
        """Return the aligned embedding as a JSON-serialisable dict (node index string -> coordinate list)"""
        # numpy rows are not JSON serialisable, so convert them to plain lists
        return {str(i): c.tolist() for i, c in enumerate(self.get_aligned_embedding())}

    def save_embedding(self, filename):
        """
        save aligned embedding to json file

        Args:
            filename: output filename
        """
        filename = ensure_extension(filename, '.json')
        embedding = self._embedding_as_dict()
        with open(filename, 'w') as f:
            json.dump(embedding, f)

    def __copy__(self):
        """return a copy of the alignment problem where all patches are copied."""
        instance = self.__new__(type(self))
        for key, value in self.__dict__.items():
            instance.__dict__[key] = copy.copy(value)
        instance.patches = [copy.copy(patch) for patch in self.patches]
        return instance

    def _synchronise(self, matrix: ss.spmatrix, blocksize=1, symmetric=False):
        """Compute the ``blocksize`` leading eigenvectors of a matrix of relative transformations

        Args:
            matrix: sparse square matrix of relative transformations
            blocksize: size of the per-patch blocks, also the number of eigenvectors computed (default: 1)
            symmetric: if ``True``, use the symmetric eigensolver (default: ``False``)

        Returns:
            array of shape ``(matrix.shape[0] // blocksize, blocksize, blocksize)`` holding one block of
            eigenvector entries per patch
        """
        dim = matrix.shape[0]
        if symmetric:
            matrix = matrix + ss.eye(dim)  # shift to ensure matrix is positive semi-definite for buckling mode
            eigs, vecs = ss.linalg.eigsh(matrix, k=blocksize, v0=rg.normal(size=dim), which='LM',
                                         sigma=2, mode='buckling')
            # eigsh unreliable with multiple (clustered) eigenvalues, only buckling mode seems to help reliably
        else:
            # scaling is not symmetric but Perron-Frobenius applies
            eigs, vecs = ss.linalg.eigs(matrix, k=blocksize, v0=rg.normal(size=dim))
            eigs = eigs.real
            vecs = vecs.real

        # order eigenvectors by decreasing eigenvalue
        order = np.argsort(eigs)
        vecs = vecs[:, order[-1:-blocksize-1:-1]]
        if self.verbose:
            print(f'eigenvalues: {eigs}')
        # group the rows into one (blocksize x blocksize) block per patch
        vecs = vecs.reshape((dim//blocksize, blocksize, blocksize))
        return vecs

    def _transform_matrix(self, transform: Callable[[np.ndarray, np.ndarray], Any], dim, symmetric_weights=False):
        """Calculate matrix of relative transformations between patches

        Args:
            transform: function to compute the relative transformation
            dim: output dimension of transform should be `(dim, dim)`
            symmetric_weights: if true use symmetric weighting (default: False)

        Returns:
            sparse matrix with one entry (``dim == 1``) or one ``dim x dim`` block per patch edge
        """
        n = self.n_patches  # number of patches
        if dim != 1:
            # construct matrix of rotations as a block-sparse-row matrix
            data = np.empty(shape=(len(self.patch_overlap), dim, dim))
        else:
            data = np.empty(shape=(len(self.patch_overlap),))
        weights = np.zeros(n)
        indptr = np.zeros((n + 1,), dtype=int)
        np.cumsum(self.patch_degrees, out=indptr[1:])
        index = np.empty(shape=(len(self.patch_overlap),), dtype=int)

        # sorted edges are grouped by first patch, matching the row layout in indptr
        keys = sorted(self.patch_overlap.keys())
        # TODO: this could be sped up by a factor of two by not computing rotations twice
        for count, (i, j) in enumerate(keys):
            if i == j:
                # identity for self-loops; a plain scalar in the 1-d case so
                # that no (1, 1) array has to be stored into a scalar slot
                element = np.eye(dim) if dim != 1 else 1.0
            else:
                overlap_idxs = self.patch_overlap[i, j]
                # find positions of overlapping nodes in the two reference frames
                overlap1 = self.patches[i].get_coordinates(overlap_idxs)
                overlap2 = self.patches[j].get_coordinates(overlap_idxs)
                element = transform(overlap1, overlap2)
            weight = self.weight(i, j)
            weights[i] += weight
            data[count] = element * weight
            index[count] = j

        # computed weighted average based on error weights
        if symmetric_weights:
            for i in range(n):
                for ind in range(indptr[i], indptr[i + 1]):
                    data[ind] /= np.sqrt(weights[i]*weights[index[ind]])
        else:
            for i in range(n):
                data[indptr[i]:indptr[i+1]] /= weights[i]
        if dim == 1:
            matrix = ss.csr_matrix((data, index, indptr), shape=(n, n))
        else:
            matrix = ss.bsr_matrix((data, index, indptr), shape=(dim*n, dim*n), blocksize=(dim, dim))
        return matrix
class WeightedAlignmentProblem(AlignmentProblem):
    """
    Variant of the local2global algorithm where patch edges are weighted according to the number of nodes in the overlap.
    """

    def weight(self, i, j):
        """
        Weight of the edge between two patches.

        Args:
            i: first patch index
            j: second patch index

        Returns:
            number of shared nodes between patches `i` and `j`
        """
        shared_nodes = self.patch_overlap[i, j]
        return len(shared_nodes)