Source code for

#  Copyright (c) 2021. Lucas G. S. Jeub
#  Permission is hereby granted, free of charge, to any person obtaining a copy
#  of this software and associated documentation files (the "Software"), to deal
#  in the Software without restriction, including without limitation the rights
#  to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
#  copies of the Software, and to permit persons to whom the Software is
#  furnished to do so, subject to the following conditions:
#  The above copyright notice and this permission notice shall be included in all
#  copies or substantial portions of the Software.
import json
from bisect import bisect_left
from pathlib import Path
import asyncio
import sys
from collections import UserDict
import argparse
from inspect import signature
import typing
import time
from ast import literal_eval
from traceback import print_exception
from runpy import run_path
import operator

import torch
from docstring_parser import parse as parse_doc
from filelock import SoftFileLock
from atomicwrites import atomic_write
from dask.distributed import Client, get_client

from local2global_embedding.classfication import ClassificationProblem

# Module-level registries, filled at import time by the ``dataloader`` and
# ``classificationloader`` decorators and read by ``load_data`` and
# ``load_classification_problem`` respectively.
_dataloaders = {}  #: maps data set name -> registered data loader function
_classification_loader = {}  #: maps data set name -> registered classification data loader function

def dataloader(name):
    """
    decorator factory for registering dataloader functions

    Args:
        name: data set name, used as key in the module-level ``_dataloaders`` registry

    Returns:
        decorator that stores the decorated function under ``name`` and returns it unchanged
    """
    def register(loader_func):
        # registration is the only side effect; the function itself is not wrapped
        _dataloaders[name] = loader_func
        return loader_func

    return register
def classificationloader(name):
    """
    decorator factory for registering classification data loaders

    Args:
        name: data set name, used as key in the module-level ``_classification_loader`` registry

    Returns:
        decorator that stores the decorated function under ``name`` and returns it unchanged
    """
    def register(loader_func):
        # registration is the only side effect; the function itself is not wrapped
        _classification_loader[name] = loader_func
        return loader_func

    return register
def load_data(name, root='/tmp', restrict_lcc=False, **kwargs):
    """
    load data set

    Args:
        name: name of data set (one of {names})
        root: root dir to store downloaded data (default '/tmp')
        restrict_lcc: if ``True``, return ``data.lcc(relabel=True)`` instead of the loaded data
        **kwargs: forwarded unchanged to the registered loader function

    Returns:
        graph data
    """
    # NOTE: this docstring is passed through ``str.format`` after the definition,
    # so it must not contain braces other than the ``names`` placeholder.
    data_root = Path(root).expanduser()
    loader = _dataloaders[name]
    graph = loader(data_root, **kwargs)
    return graph.lcc(relabel=True) if restrict_lcc else graph
def load_classification_problem(name, graph, root='/tmp', **class_args):
    """
    load labels and splits for a data set and wrap them in a ``ClassificationProblem``

    Args:
        name: data set name (key in the ``_classification_loader`` registry)
        graph: graph object; only ``has_node_labels()`` and ``nodes`` are used
        root: root dir passed on to the registered loader (default '/tmp')
        **class_args: forwarded unchanged to the registered loader

    Returns:
        ``ClassificationProblem`` built from the (possibly re-indexed) labels and splits
    """
    root = Path(root).expanduser()
    y, split = _classification_loader[name](root=root, **class_args)
    if not graph.has_node_labels():
        return ClassificationProblem(y, split=split)

    nodes = graph.nodes
    # lookup table: original label index -> position within ``graph.nodes`` (-1 if absent)
    # NOTE(review): presumably ``y`` is 1-dimensional with one entry per original node -- confirm
    position = torch.full(y.shape, -1, dtype=torch.long)
    position[nodes] = torch.arange(len(nodes), dtype=torch.long)
    y = y[nodes]
    for part in split:
        relabelled = position[split[part]]
        # entries that map to -1 are not listed in ``graph.nodes`` and are dropped
        split[part] = relabelled[relabelled >= 0]
    return ClassificationProblem(y, split=split)
# Fill the ``names`` placeholder of the ``load_data`` docstring with the data set
# names present in ``_dataloaders`` at the time this statement runs.
# ``__doc__`` is ``None`` when docstrings are stripped (``python -OO``); skip the
# formatting in that case so that importing the module does not fail.
if load_data.__doc__ is not None:
    load_data.__doc__ = load_data.__doc__.format(names=list(_dataloaders.keys()))
def cluster_string(cluster='metis', num_clusters=10, num_iters: int=None, beta=0.1, levels=1):
    """
    build the string that identifies a clustering configuration

    Args:
        cluster: clustering method, one of 'louvain', 'distributed', 'fennel', 'metis'
        num_clusters: number of clusters (used in the string for 'fennel' and 'metis')
        num_iters: number of iterations (used in the string for 'distributed' and 'fennel')
        beta: beta value (used in the string for 'distributed')
        levels: if larger than 1, a ``_hc<levels>`` suffix is appended

    Returns:
        identifier string

    Raises:
        RuntimeError: if ``cluster`` is not one of the known methods
    """
    if cluster == 'metis':
        label = f"metis_n{num_clusters}"
    elif cluster == 'fennel':
        label = f"fennel_n{num_clusters}_it{num_iters}"
    elif cluster == 'distributed':
        label = f'distributed_beta{beta}_it{num_iters}'
    elif cluster == 'louvain':
        label = 'louvain'
    else:
        raise RuntimeError(f"Unknown cluster method '{cluster}'.")

    if levels > 1:
        return label + f'_hc{levels}'
    return label
def patch_folder_name(name: str, min_overlap: int, target_overlap: int, cluster='metis', num_clusters=10,
                      num_iters: int=None, beta=0.1, levels=1, sparsify='resistance', target_patch_degree=4.0,
                      gamma=0.0):
    """
    build the folder name for a patch configuration

    Args:
        name: data set name (prefix of the result)
        min_overlap: value written after ``_mo``
        target_overlap: value written after ``_to``
        cluster: clustering method, passed to ``cluster_string``
        num_clusters: passed to ``cluster_string``
        num_iters: passed to ``cluster_string``
        beta: passed to ``cluster_string``
        levels: passed to ``cluster_string``
        sparsify: sparsification method, one of 'resistance', 'rmst', 'none', 'sample', 'neighbors'
        target_patch_degree: used in the string for 'resistance', 'sample' and 'neighbors'
        gamma: used in the string for 'rmst'

    Returns:
        folder name string

    Raises:
        RuntimeError: if ``sparsify`` (or ``cluster``, via ``cluster_string``) is unknown
    """
    # the sparsification method is validated first, before the cluster string is built
    if sparsify == 'none':
        sparsify_part = "no_sparsify"
    elif sparsify == 'rmst':
        sparsify_part = f"rmst_gamma{gamma}"
    elif sparsify == 'resistance':
        sparsify_part = f"resistance_deg{target_patch_degree}"
    elif sparsify == 'sample':
        sparsify_part = f'sample_deg{target_patch_degree}'
    elif sparsify == 'neighbors':
        sparsify_part = f'neighbors_deg{target_patch_degree}'
    else:
        raise RuntimeError(f"Unknown sparsification method '{sparsify}'.")

    cluster_part = cluster_string(cluster, num_clusters, num_iters, beta, levels)
    return f'{name}_{cluster_part}_{sparsify_part}_mo{min_overlap}_to{target_overlap}_patches'
def cluster_file_name(name, cluster='metis', num_clusters=10, num_iters: int=None, beta=0.1, levels=1):
    """
    build the file name for a clustering result

    Args:
        name: data set name (prefix of the result)
        cluster: clustering method, passed to ``cluster_string``
        num_clusters: passed to ``cluster_string``
        num_iters: passed to ``cluster_string``
        beta: passed to ``cluster_string``
        levels: passed to ``cluster_string``

    Returns:
        ``name`` and the cluster string joined by an underscore
    """
    suffix = cluster_string(cluster=cluster, num_clusters=num_clusters, num_iters=num_iters, beta=beta,
                            levels=levels)
    return f'{name}_{suffix}'
class NoLock:
    """
    Lock stand-in whose operations all do nothing

    Provides ``acquire``, ``release`` and the context-manager protocol so it can
    be used wherever a real lock object is expected.
    """

    def acquire(self):
        """do nothing"""
        return None

    def release(self):
        """do nothing"""
        return None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # returns ``None``, so exceptions raised inside the ``with`` block propagate
        return None
class SyncDict(UserDict):
    """
    Class for keeping json-backed dict with locking for sync

    The contents live in ``self.data`` (see ``collections.UserDict``) and are
    read from / written to ``self.filename`` as json. Used as a context manager,
    the lock is held for the duration of the ``with`` block and the contents
    are reloaded from file on entry. Changes are only written by ``save``.
    """

    def __init__(self, filename, lock=True):
        """
        initialise from ``filename`` if it exists, otherwise start empty

        Args:
            filename: file to use for storage
            lock: set lock=False to avoid locking (use wisely)
        """
        super().__init__()
        self.filename = Path(filename)
        if lock:
            self._lock = SoftFileLock(self.filename.with_suffix('.lock'), timeout=10)
        else:
            self._lock = NoLock()  # implements lock interface without doing anything
        with self._lock:
            if self.filename.is_file():
                self.load()
            else:
                # create the backing file so that a later ``load`` (e.g. in ``__enter__``) finds it
                self.save()

    def load(self):
        """
        replace the contents with the data stored in ``self.filename``

        Raises:
            OSError: if the file cannot be opened
            json.JSONDecodeError: if the file does not contain valid json
        """
        with self._lock:
            with open(self.filename) as f:
                self.data = json.load(f)

    def save(self):
        """
        dump contents to ``self.filename`` as json
        """
        with self._lock:
            with atomic_write(self.filename, overwrite=True) as f:  # this should avoid any chance of losing existing data
                json.dump(self.data, f)

    def __enter__(self):
        self._lock.acquire()
        try:
            self.load()
        except BaseException:
            # ``__exit__`` is not called when ``__enter__`` raises, so release here
            self._lock.release()
            raise
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self._lock.release()
class ResultsDict(UserDict):
    """
    Class for keeping track of results

    The contents live in ``self.data`` (see ``collections.UserDict``) and are
    stored as json in ``self.filename``. ``self['dims']`` is a sorted list of
    dimensions and ``self['runs']`` holds one run counter per dimension. Every
    other field holds one list of values per dimension, aligned with ``self['dims']``.
    """

    def __init__(self, filename, replace=False, lock=True):
        """
        initialise from ``filename`` if it exists, otherwise start with empty results

        Args:
            filename: file to use for storage
            replace: set the replace attribute (default: ``False``)
            lock: set lock=False to avoid locking (use wisely)
        """
        super().__init__()
        self.filename = Path(filename)
        if lock:
            self._lock = SoftFileLock(self.filename.with_suffix('.lock'), timeout=10)
        else:
            self._lock = NoLock()  # implements lock interface without doing anything
        with self._lock:
            if not self.filename.is_file():
                self.data = {'dims': [], 'runs': []}
            else:
                self.load()
        self.replace = replace  #: if ``True``, updates replace existing data, if ``False``, updates append data

    def load(self):
        """
        replace the contents with the data stored in ``self.filename``

        Raises:
            OSError: if the file cannot be opened
            json.JSONDecodeError: if the file does not contain valid json
        """
        with self._lock:
            with open(self.filename) as f:
                self.data = json.load(f)

    def save(self):
        """
        dump contents to ``self.filename`` as json
        """
        with self._lock:
            with atomic_write(self.filename, overwrite=True) as f:  # this should avoid any chance of losing existing data
                json.dump(self.data, f)

    def __enter__(self):
        self._lock.acquire()
        try:
            # a new ResultsDict has no file until ``save`` is called; keep the in-memory data then
            if self.filename.is_file():
                self.load()
        except BaseException:
            # ``__exit__`` is not called when ``__enter__`` raises, so release here
            self._lock.release()
            raise
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self._lock.release()

    def _update_index(self, index, replace=False, **kwargs):
        """
        update data for a given index

        Args:
            index: integer index into data lists
            replace: if ``True`` (or if ``self.replace`` is ``True``), overwrite the stored
                     values instead of appending
            **kwargs: new values, keyed by field name
        """
        overwrite = replace or self.replace
        for key, val in kwargs.items():
            if key not in self:
                # new field: start with an empty value list for every known dimension
                self.data[key] = [[] for _ in self['dims']]
            if overwrite:
                self[key][index] = [val]
            else:
                self[key][index].append(val)
        if not overwrite:
            self['runs'][index] += 1

    def _insert_index(self, index: int, dim: int, **kwargs):
        """
        insert new data at index

        Args:
            index: integer index into data lists
            dim: data dimension for index
            **kwargs: new values, keyed by field name
        """
        self['dims'].insert(index, dim)
        for key in self:
            if key not in {'dims', 'runs'}:  # these only store single values per dimension
                self[key].insert(index, [])
        for key, val in kwargs.items():
            if key not in self:
                self.data[key] = [[] for _ in self['dims']]
            self[key][index].append(val)
        self['runs'].insert(index, 1)

    def _index(self, dim):
        """return the position of ``dim`` in the sorted ``self['dims']`` (insertion point if absent)"""
        return bisect_left(self['dims'], dim)

    def _reduce_field(self, reducer, empty, field, dim):
        """
        shared implementation of ``max`` and ``min``

        Args:
            reducer: builtin ``max`` or ``min``
            empty: value to report where no data is available
            field: field to reduce
            dim: dimension to reduce, or ``None`` for all dimensions
        """
        if field not in self:
            return [empty for _ in self['dims']] if dim is None else empty
        if dim is None:
            # a field added after some dimensions has empty lists for those dimensions
            return [reducer(values, default=empty) for values in self[field]]
        if self.contains_dim(dim):
            return reducer(self[field][self._index(dim)], default=empty)
        return empty

    def update_dim(self, dim, replace=False, **kwargs):
        """
        update data for given dimension

        Args:
            dim: dimension to update
            replace: if ``True``, overwrite existing values for ``dim`` instead of appending
            **kwargs: new values, keyed by field name

        if ``self.contains_dim(dim) == True``, behaviour depends on ``replace`` and ``self.replace``,
        otherwise a new entry is inserted for ``dim``.
        """
        index = self._index(dim)
        if index < len(self['dims']) and self['dims'][index] == dim:
            self._update_index(index, replace=replace, **kwargs)
        else:
            self._insert_index(index, dim, **kwargs)

    def delete_dim(self, dim):
        """
        remove all data for dimension ``dim`` (does nothing if ``dim`` is not present)

        Args:
            dim: dimension to remove
        """
        index = self._index(dim)
        if index < len(self['dims']) and self['dims'][index] == dim:
            for values in self.values():
                del values[index]

    def max(self, field, dim=None):
        """
        return maximum values of a field

        Args:
            field: field to take maximum over
            dim: if ``dim=None``, return list of values for all dimension, else only return maximum value for ``dim``.

        Returns:
            maximum value(s); ``-inf`` where no data is available
        """
        return self._reduce_field(max, -float('inf'), field, dim)

    def min(self, field, dim=None):
        """
        return minimum values of a field

        Args:
            field: field to take minimum over
            dim: if ``dim=None``, return list of values for all dimension, else only return minimum value for ``dim``.

        Returns:
            minimum value(s); ``inf`` where no data is available
        """
        return self._reduce_field(min, float('inf'), field, dim)

    def get(self, item, dim=None, default=None):
        """
        return stored data for a field

        Args:
            item: field name
            dim: if ``None``, return the data for all dimensions, else only the data for ``dim``
            default: value returned if ``item`` or ``dim`` is not present

        Note that the second positional argument is ``dim``, not ``default`` as in ``dict.get``.
        """
        if item in self:
            if dim is None:
                return self[item]
            if self.contains_dim(dim):
                return self[item][self._index(dim)]
        return default

    def contains_dim(self, dim):
        """
        equivalent to ``dim in self['dims']``
        """
        index = self._index(dim)
        return index < len(self['dims']) and self['dims'][index] == dim

    def reduce_to_dims(self, dims):
        """
        remove all data for dimensions not in ``dims``

        Args:
            dims: list of dimensions to keep

        Returns:
            self
        """
        wanted = set(dims)
        # positions within the stored data (not within ``dims``), in stored order so ``dims`` stays sorted
        keep = [i for i, d in enumerate(self['dims']) if d in wanted]
        for key, values in self.data.items():
            if isinstance(values, list):
                self.data[key] = [values[i] for i in keep]
        return self

    def runs(self, dim=None):
        """
        return the number of runs

        Args:
            dim: if ``dim is None``, return list of number of runs for all dimension, else return number of
                 runs for dimension ``dim``.
        """
        if dim is None:
            return self['runs']
        if self.contains_dim(dim):
            return self['runs'][self._index(dim)]
        return 0
class Throttler:
    """
    Helper that spaces out successive ``submit_ok`` calls by ``min_interval`` seconds

    Times are measured with ``time.monotonic``.
    """

    def __init__(self, min_interval=0):
        """
        Args:
            min_interval: interval in seconds added between scheduled runs (default: 0)
        """
        self.min_interval = min_interval
        self.next_run = time.monotonic()

    async def submit_ok(self):
        """
        wait until the next submission is allowed

        If more than ``min_interval`` has passed since the scheduled ``next_run``,
        return immediately, otherwise push ``next_run`` back by ``min_interval``
        and sleep until then.
        """
        current = time.monotonic()
        if current > self.next_run + self.min_interval:
            # enough time has passed, no need to wait
            self.next_run = current
            return
        self.next_run += self.min_interval
        delay = self.next_run - current
        await asyncio.sleep(delay)
async def run_script(script_name, _cmd_prefix=None, _task_queue: asyncio.Queue = None, _throttler: "Throttler" = None,
                     _stderr=False, **kwargs):
    """
    run ``python -m script_name`` as a subprocess and wait for it to finish

    Args:
        script_name: module name passed to ``python -m``
        _cmd_prefix: optional command prefix; split on whitespace and put in front of ``python``
        _task_queue: optional queue used to limit the number of simultaneous tasks
                     (one item is put before the task starts and removed when it ends)
        _throttler: optional ``Throttler`` used to limit task creation frequency
        _stderr: if ``True``, redirect the subprocess' standard output to ``sys.stderr``
        **kwargs: passed to the script as ``--key=value`` options
    """
    args = []
    if _cmd_prefix is not None:
        args.extend(_cmd_prefix.split())
    args.extend(['python', '-m', f'{script_name}'])
    args.extend(f'--{key}={value}' for key, value in kwargs.items())

    if _task_queue is not None:
        await _task_queue.put(args)  # limit number of simultaneous tasks
    try:
        if _throttler is not None:
            await _throttler.submit_ok()  # limit task creation frequency
        stdout = sys.stderr if _stderr else None  # optionally redirect all output to stderr
        proc = await asyncio.create_subprocess_exec(*args, stdout=stdout)
        await proc.communicate()
    finally:
        # always give the slot back, otherwise a failed launch would block the queue forever
        if _task_queue is not None:
            _task_queue.get_nowait()
            _task_queue.task_done()
class CSVList:
    """
    Callable that parses a comma-separated string into a list of ``dtype`` values
    """

    def __init__(self, dtype=str):
        """
        Args:
            dtype: callable applied to every item (default: ``str``)
        """
        self.dtype = dtype

    def __call__(self, input: str):
        """
        split ``input`` at ',' and convert each piece with ``self.dtype``

        Args:
            input: comma-separated string

        Returns:
            list of converted items
        """
        convert = self.dtype
        pieces = input.split(',')
        return list(map(convert, pieces))
class BooleanString:
    """
    Converter for the strings 'True' and 'False'

    Calling the class returns a plain ``bool`` rather than an instance.
    """

    def __new__(cls, s):
        """
        Args:
            s: string to convert

        Returns:
            ``True`` for 'True', ``False`` for 'False'

        Raises:
            ValueError: if ``s`` is neither 'True' nor 'False'
        """
        accepted = {'False', 'True'}
        if s in accepted:
            return s == 'True'
        raise ValueError('Not a valid boolean string')
class Union:
    """
    Callable that converts a string with the first of several types that accepts it
    """

    def __init__(self, types):
        """
        Args:
            types: iterable of callables, tried in order
        """
        self.types = types

    def __call__(self, value: str):
        """
        Args:
            value: string to convert

        Returns:
            ``None`` if ``value`` is 'None' and ``type(None)`` is among the types,
            otherwise the result of the first type that converts ``value`` without raising

        Raises:
            RuntimeError: if no type can convert ``value``
        """
        none_allowed = type(None) in self.types
        if value == 'None' and none_allowed:
            return None
        for convert in self.types:
            try:
                converted = convert(value)
            except Exception:
                # this type does not accept the value, try the next one
                continue
            return converted
        raise RuntimeError(f'Cannot parse argument {value}')
class Argument:
    """
    Argument wrapper for ScriptParser

    Keeps the name, the default (if any) and the parsed value of a single
    argument. Instances are callable, so they can be used as an ``argparse``
    ``type`` that records the parsed value on the instance itself.
    """

    def __init__(self, name='', parameter=None, allow_reset=False):
        """
        Initialize Argument

        Args:
            name: argument name
            parameter: signature parameter (optional, used to get default value if specified)
            allow_reset: if ``True``, the value may be set more than once (last value wins)
        """
        self.name = name
        self.allow_reset = allow_reset
        # without a parameter (or without a default) a value has to be provided
        self.required = parameter is None or parameter.default is parameter.empty
        self.is_set = False
        if not self.required:
            self._value = parameter.default
        else:
            self._value = None

    def __call__(self, input_str):
        """
        parse argument string

        Args:
            input_str: string to evaluate

        Returns:
            self

        Raises:
            RuntimeError: if the value is already set and ``allow_reset`` is ``False``

        Tries to parse `input_str` as python code using `ast.literal_eval`.
        If this fails, sets value to `input_str`.
        """
        if self.is_set and not self.allow_reset:
            raise RuntimeError(f"Tried to set value for argument {self.name!r} multiple times, "
                               f"old value: {self._value}, new value: {input_str}")
        self.is_set = True
        try:
            val = literal_eval(input_str)
        except Exception:  # could not interpret as python literal, assume it is a bare string argument
            val = input_str
        self._value = val
        return self

    @property
    def value(self):
        """
        current value (parsed value if set, else the default)

        Raises:
            RuntimeError: if the argument is required and no value has been set
        """
        if not self.required or self.is_set:
            return self._value
        raise RuntimeError(f"Missing value for required argument {self.name!r}")

    def __repr__(self):
        repr_str = f'{self.__class__.__name__}(name={self.name!r})'
        if self.required:
            repr_str += ', required'
        if self.is_set:
            repr_str += f', value={self.value!r}'
        elif not self.required:
            repr_str += f', default={self.value!r}'
        return repr_str
class ScriptParser:
    """
    Build a command-line interface to a python function

    Inspects the function signature to create command-line arguments. It converts argument `arg` to
    long option `--arg`. Parsing is similar to python, supporting mix of positional and named arguments
    (note its possible to specify named arguments before positional arguments).
    Also supports use of `*args` and `**kwargs`.

    Help messages are constructed by parsing the doc-string of the wrapped function.

    Can be used as a decorator if function should only be used as a script
    """

    def __init__(self, func, allow_reset=False):
        """
        Wrap `func` as a command-line interface

        Args:
            func: Callable
            allow_reset: if ``True``, an option given more than once takes its last value
                         instead of raising an error
        """
        self.func = func
        self.parser = argparse.ArgumentParser(prog=func.__name__)
        self.allow_reset = allow_reset
        self.var_pos = False
        self.var_keyword = False
        self.arguments = []
        sig = signature(func).parameters
        docstring = parse_doc(func.__doc__)
        help_text = {p.arg_name: p.description for p in docstring.params}
        self.parser.description = docstring.short_description
        # all positional command-line values are collected here and matched to arguments in ``parse``
        self.parser.add_argument('_pos', nargs='*')
        for name, parameter in sig.items():
            if parameter.kind == parameter.VAR_POSITIONAL:
                if not self.var_pos:
                    self.var_pos = True
                else:
                    raise RuntimeError('Only expected a single *args')
            elif parameter.kind == parameter.VAR_KEYWORD:
                if not self.var_keyword:
                    self.var_keyword = True
                else:
                    raise RuntimeError('Only expected a single **kwargs')
            else:
                arg = Argument(name, parameter, allow_reset)
                self.arguments.append(arg)
                # the Argument instance is both ``type`` and ``default``, so the parsed
                # namespace always holds the Argument itself, set or not
                if arg.required:
                    help_str = help_text.get(name, name)
                else:
                    help_str = f'{help_text.get(name, name)} (default: {arg.value!r})'
                self.parser.add_argument(f'--{name}', type=arg, default=arg, help=self._escape_help(help_str))

    @staticmethod
    def _escape_help(text):
        """escape '%' because argparse applies %-formatting to help strings"""
        if text is None:
            return None
        return text.replace('%', '%%')

    def parse(self, args=None):
        """
        parse command-line arguments into positional and keyword arguments for the wrapped function

        Args:
            args: list of argument strings (default: ``sys.argv[1:]``)

        Returns:
            tuple ``(args, kwargs)`` of positional argument list and keyword argument dict

        Raises:
            RuntimeError: for unknown keyword arguments, too many positional arguments
                          or missing required arguments
        """
        if args is None:
            args = sys.argv[1:]
        if self.var_keyword:
            arg_res, unknown = self.parser.parse_known_args(args)
            kwargs = vars(arg_res)
            # options not in the signature are collected with a second parser and end up in ``**kwargs``
            unknown_parser = argparse.ArgumentParser()
            registered = set()
            for arg in unknown:
                if arg.startswith("--"):
                    name = arg[2:].split('=', 1)[0]
                    if name in registered:
                        # a repeated option must not be added twice (argparse rejects that);
                        # the Argument itself decides whether resetting is allowed
                        continue
                    registered.add(name)
                    new_arg = Argument(name, allow_reset=self.allow_reset)
                    unknown_parser.add_argument(f'--{name}', type=new_arg, default=new_arg)
                    self.arguments.append(new_arg)
            kwargs.update(vars(unknown_parser.parse_args(unknown)))
        else:
            kwargs = vars(self.parser.parse_args(args))

        args = []
        pos_args = kwargs.pop('_pos')
        # assign positional values to arguments in signature order until one is already set by name
        for arg, val in zip(self.arguments, pos_args):
            if val.startswith('--'):
                raise RuntimeError(f'Unknown keyword argument {val}')
            if not arg.is_set:
                arg = arg(val)
                args.append(arg.value)
                kwargs.pop(arg.name)
            else:
                break
        if self.var_pos:
            # remaining positional values are passed on as ``*args``
            args.extend(Argument()(val).value for val in pos_args[len(args):])
        else:
            if len(args) != len(pos_args):
                raise RuntimeError(f'Too many positional arguments specified. {pos_args=}')
        for name, value in kwargs.items():
            kwargs[name] = value.value  # raises for required arguments without value
        return args, kwargs

    def run(self, args=None):
        """
        run the wrapped function with arguments passed on sys.argv or as list of string arguments
        """
        args, kwargs = self.parse(args)
        self.func(*args, **kwargs)

    def __call__(self, args=None):
        """
        same as ``run``
        """
        self.run(args)
def watch_progress(tasks):
    """
    report the outcome of each task as it is yielded by ``tasks``

    Args:
        tasks: iterable of task objects providing ``status``, ``exception()`` and ``traceback()``

    For tasks with status 'error' the exception is printed, all other tasks are reported as complete.
    """
    for task in tasks:
        if task.status != 'error':
            print(f'{task} complete')
        else:
            print(f'{task} errored')
            err = task.exception()
            print_exception(type(err), err, task.traceback())
        del task  # drop the reference to the task before waiting for the next one
def get_or_init_client(init_cluster=True):
    """
    return the current dask client, creating one if ``get_client`` raises ``ValueError``

    Args:
        init_cluster: if ``True``, create the client from the ``cluster`` variable defined by
                      the cluster init script, else launch a default ``Client()``

    Returns:
        dask client
    """
    try:
        client = get_client()
    except ValueError:
        if not init_cluster:
            print('launching default client')
            client = Client()
        else:
            print('setting up cluster')
            # NOTE(review): the last path component is an empty string, so this resolves to the
            # ``~/.config/dask`` directory itself -- presumably a script file name is intended; confirm
            cluster_init_path = Path().home() / '.config' / 'dask' / ''
            init_globals = run_path(cluster_init_path)
            client = Client(init_globals['cluster'])
    return client
def dask_unpack(client, result_future, n):
    """
    Unpack dask future of iterable

    Args:
        client: dask client
        result_future: future to unpack
        n: number of items to unpack

    Returns:
        list of unpacked futures
    """
    items = []
    for position in range(n):
        # one task per item: ``operator.getitem(result_future, position)``
        items.append(client.submit(operator.getitem, result_future, position))
    return items