Source code for local2global_embedding.run.scripts.l2g_align_patches

#  Copyright (c) 2021. Lucas G. S. Jeub
#
#  Permission is hereby granted, free of charge, to any person obtaining a copy
#  of this software and associated documentation files (the "Software"), to deal
#  in the Software without restriction, including without limitation the rights
#  to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
#  copies of the Software, and to permit persons to whom the Software is
#  furnished to do so, subject to the following conditions:
#
#  The above copyright notice and this permission notice shall be included in all
#  copies or substantial portions of the Software.
#
#  THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
#  IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
#  FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
#  AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
#  LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
#  OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
#  SOFTWARE.

from pathlib import Path
from typing import Optional
from shutil import copyfile
from tempfile import NamedTemporaryFile, TemporaryFile
from local2global_embedding.utils import Timer

import torch
import numpy as np
from numpy.lib.format import open_memmap
from filelock import SoftFileLock
from tqdm import tqdm

from local2global.utils import WeightedAlignmentProblem, SVDAlignmentProblem
from local2global.patch import Patch, FilePatch
from local2global_embedding.run.utils import ScriptParser


[docs] def l2g_align_patches(patch_folder: str, basename: str, dim: int, criterion: str, mmap=False, use_tmp=False, verbose=False): print(f'computing aligned embedding for {patch_folder}/{basename}_d{dim}') patch_folder = Path(patch_folder) patch_graph = torch.load(patch_folder / 'patch_graph.pt', map_location='cpu') with SoftFileLock(patch_folder / f'{basename}_d{dim}_{criterion}_coords.lock', timeout=10): # only one task at a time patch_list = [] for i in tqdm(range(patch_graph.num_nodes), desc='loading patch data', smoothing=0): node_file = patch_folder / f'patch{i}_index.npy' coords_file = patch_folder / f'{basename}_patch{i}_d{dim}_best_{criterion}_coords.npy' if node_file.is_file(): nodes = np.load(node_file) else: patch_file = patch_folder / f'patch{i}_data.pt' patch = torch.load(patch_file, map_location='cpu') nodes = patch.nodes if mmap: if use_tmp: coords_file_tmp = NamedTemporaryFile(delete=False) coords_file_tmp.close() copyfile(coords_file, coords_file_tmp.name) coords_file = coords_file_tmp.name patch_list.append(FilePatch(nodes, coords_file)) else: with SoftFileLock(f'{basename}_patch{i}_info.lock', timeout=10): coords = np.load(coords_file) patch_list.append(Patch(nodes, coords)) print('initialising alignment problem') timer = Timer() with timer: prob = SVDAlignmentProblem(patch_list, patch_edges=patch_graph.edges(), copy_data=False, verbose=verbose) patched_embedding_file = patch_folder / f'{basename}_d{dim}_{criterion}_coords.npy' patched_embedding_file_nt = patch_folder / f'{basename}_d{dim}_{criterion}_ntcoords.npy' if mmap: print('computing ntcoords using mmap') if use_tmp: print('using tmp buffer') with TemporaryFile() as f: buffer = np.memmap(f, dtype=np.float32, shape=(prob.n_nodes, prob.dim)) prob.mean_embedding(buffer) np.save(patched_embedding_file_nt, buffer) else: out = open_memmap(patched_embedding_file_nt, mode='w+', shape=(prob.n_nodes, prob.dim), dtype=np.float32) prob.mean_embedding(out) out.flush() else: print('computing ntcoords') out = np.zeros(shape=(prob.n_nodes, prob.dim), dtype=np.float32) ntcoords = prob.mean_embedding(out) np.save(patched_embedding_file_nt, ntcoords) if mmap: print('computing aligned coords using mmap') if use_tmp: with TemporaryFile() as f: buffer = np.memmap(f, dtype=np.float32, shape=(prob.n_nodes, prob.dim)) with timer: prob.align_patches().mean_embedding(buffer) np.save(patched_embedding_file, buffer) else: out = open_memmap(patched_embedding_file, mode='w+', shape=(prob.n_nodes, prob.dim), dtype=np.float32) with timer: prob.align_patches().mean_embedding(out) out.flush() else: print('computing aligned coords') out = np.empty(shape=(prob.n_nodes, prob.dim), dtype=np.float32) with timer: coords = prob.align_patches().mean_embedding(out) np.save(patched_embedding_file, coords) with open(patched_embedding_file.with_name(patched_embedding_file.stem + "timing.txt")) as f: f.write(timer.total)
if __name__ == '__main__': ScriptParser(l2g_align_patches).run()