# Copyright (c) 2021. Lucas G. S. Jeub
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import numpy as np
import torch
from pathlib import Path
from dask import delayed
from dask.distributed import worker_client, Client
from distributed import secede, rejoin
from local2global_embedding.clustering import Partition
from local2global_embedding.sparsify import resistance_sparsify
from local2global_embedding.utils import Timer
from .utils import mean_embedding, aligned_coords
from local2global_embedding.run.utils import ScriptParser
from filelock import SoftFileLock
# [docs]
def get_aligned_embedding(patch_graph, patches, clusters, verbose=True, use_tmp=False, resparsify=0,
                          scale=False, rotate=True, translate=True, time=0.0):
    """Align patches level by level through a hierarchy of clusterings.

    For every clustering in ``clusters`` (processed in order) the patches that
    fall into the same part are aligned with each other via ``aligned_coords``;
    the results become the patches of the next, coarser level, whose patch
    graph is ``patch_graph.partition_graph(cluster)``. Once no clustering is
    left, the remaining patches are aligned in one final ``aligned_coords``
    call.

    Args:
        patch_graph: patch graph for ``patches``; must provide
            ``partition_graph`` and ``subgraph``.
        patches: sequence of patches, indexable by the members of each part.
        clusters: sequence of clusterings, finest level first. An empty
            sequence means a single flat alignment.
        verbose: forwarded to ``aligned_coords``.
        use_tmp: forwarded to ``aligned_coords``.
        resparsify: if greater than 0, every coarsened patch graph is passed
            through ``resistance_sparsify`` with this value.
        scale: forwarded to the final ``aligned_coords`` call only.
        rotate: forwarded to the final ``aligned_coords`` call only.
        translate: forwarded to the final ``aligned_coords`` call only.
        time: timing offset that the per-call timings are added to.

    Returns:
        Tuple ``(coords, total_time)``: the first item returned by the final
        ``aligned_coords`` call and ``time`` plus the timings reported by
        every ``aligned_coords`` call made along the way.
    """
    elapsed = time
    remaining = clusters

    while remaining:
        level = remaining[0]

        coarse_graph = patch_graph.partition_graph(level)
        if resparsify > 0:
            coarse_graph = resistance_sparsify(coarse_graph, resparsify)

        merged = []
        for members in Partition(level):
            # NOTE(review): scale/rotate/translate are not forwarded here, so
            # these intermediate alignments use whatever defaults
            # aligned_coords defines -- confirm that this is intended.
            merged_patch, spent = aligned_coords(
                patch_graph=patch_graph.subgraph(members),
                patches=[patches[p] for p in members],
                verbose=verbose,
                use_tmp=use_tmp)
            merged.append(merged_patch)
            elapsed += spent

        # Move one level up the hierarchy.
        patch_graph, patches, remaining = coarse_graph, merged, remaining[1:]

    coords, spent = aligned_coords(patches, patch_graph, verbose, use_tmp, scale, rotate, translate)
    return coords, elapsed + spent
# [docs]
def hierarchical_l2g_align_patches(patch_graph, shape, patches, output_file: Path, cluster_file=None, mmap=False,
                                   verbose=False, use_tmp=False, resparsify=0, store_aligned_patches=False,
                                   scale=False, rotate=True, translate=True):
    """Align patches (hierarchically if possible) and schedule saving of the result.

    If ``cluster_file`` holds a list with more than one clustering, the
    alignment follows that hierarchy via ``get_aligned_embedding`` using all
    clusterings but the first; otherwise all patches are aligned in a single
    ``aligned_coords`` call. The alignment itself runs when this function is
    called; writing the outputs is wrapped in ``dask.delayed``.

    Args:
        patch_graph: patch graph describing the overlaps between ``patches``.
        shape: shape forwarded to ``mean_embedding`` (only used if ``mmap``).
        patches: sequence of patches to align.
        output_file: destination of the embedding coordinates. The timing log
            is written next to it as ``<stem>time.txt``.
        cluster_file: optional file readable by ``torch.load`` that contains
            the list of clusterings.
        mmap: if true, compute the coordinates with ``mean_embedding``;
            otherwise convert ``aligned.coordinates`` to ``float32`` in memory
            and store it with ``np.save``.
        verbose: forwarded to the alignment routines.
        use_tmp: forwarded to the alignment routines and ``mean_embedding``.
        resparsify: forwarded to ``get_aligned_embedding``.
        store_aligned_patches: if true, additionally save the coordinates of
            every aligned patch next to its source coordinate file.
        scale: forwarded to the alignment routines; also selects the file
            name postfix used for stored aligned patches.
        rotate: forwarded to the alignment routines.
        translate: forwarded to the alignment routines.

    Returns:
        A ``dask.delayed`` object. Computing it writes the embedding, the
        optional aligned patch files and a line in the timing log, and yields
        the embedding coordinates.
    """
    if mmap:
        def get_coords(aligned):
            return mean_embedding(aligned.patches, shape, output_file, use_tmp)
    else:
        def get_coords(aligned):
            coords = np.asarray(aligned.coordinates, dtype=np.float32)
            # coords is already float32, no second conversion required.
            np.save(output_file, coords)
            return coords

    @delayed
    def save_results(aligned, time):
        coords = get_coords(aligned)
        if store_aligned_patches:
            postfix = '_aligned_scaled_coords' if scale else '_aligned_coords'
            for patch in aligned.patches:
                # assumes patch.coordinates exposes its backing file as a
                # pathlib.Path in `.filename` -- TODO confirm
                f_name = patch.coordinates.filename
                aligned_f_name = f_name.with_name(f_name.name.replace('_coords', postfix))
                np.save(aligned_f_name, patch.coordinates)
        # NOTE(review): stem and "time.txt" are joined without a separator;
        # kept as is because other code may rely on this file name.
        timing_file = output_file.with_name(output_file.stem + "time.txt")
        # The timing log is opened in append mode, so writes are serialised
        # with a lock file.
        with SoftFileLock(timing_file.with_suffix(".lock")):
            with open(timing_file, 'a') as f:
                f.write(str(time) + "\n")
        return coords

    if cluster_file is not None:
        # NOTE: torch.load unpickles the file; only use trusted cluster files.
        clusters = torch.load(cluster_file)
    else:
        clusters = None

    if isinstance(clusters, list) and len(clusters) > 1:
        # Use the list that was just loaded. Replacing it by
        # delayed(torch.load)(cluster_file) would hand a dask Delayed to
        # get_aligned_embedding, whose `if not clusters` check raises
        # TypeError because Delayed objects have no truth value.
        aligned, time = get_aligned_embedding(
            patch_graph=patch_graph, patches=patches, clusters=clusters[1:], verbose=verbose, use_tmp=use_tmp,
            resparsify=resparsify, scale=scale, rotate=rotate, translate=translate)
    else:
        aligned, time = aligned_coords(patches, patch_graph, verbose, use_tmp, scale, rotate, translate)
    return save_results(aligned, time)
if __name__ == '__main__':
    # Script entry point: start a dask.distributed client before the parser
    # runs the alignment. The reference is kept in `client` so the client
    # object stays alive for the duration of the run.
    client = Client()
    # ScriptParser wraps the function; run() invokes it (presumably with
    # arguments taken from the command line -- see ScriptParser).
    parser = ScriptParser(hierarchical_l2g_align_patches)
    parser.run()