# Copyright (c) 2021. Lucas G. S. Jeub
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import json
from bisect import bisect_left
from pathlib import Path
import asyncio
import sys
from collections import UserDict
import argparse
from inspect import signature
import typing
import time
from ast import literal_eval
from traceback import print_exception
from runpy import run_path
import operator
import torch
from docstring_parser import parse as parse_doc
from filelock import SoftFileLock
from atomicwrites import atomic_write
from dask.distributed import Client, get_client
from local2global_embedding.classfication import ClassificationProblem
_dataloaders = {}  #: registered dataloader functions
_classification_loader = {}  #: registered classification problem loaders
def dataloader(name):
"""
decorator for registering dataloader functions
Args:
name: data set name
"""
def loader(func):
_dataloaders[name] = func
return func
return loader
def classificationloader(name):
"""
decorator for registering classification data loaders
Args:
name: data set name
Returns:
"""
def loader(func):
_classification_loader[name] = func
return func
return loader
def load_data(name, root='/tmp', restrict_lcc=False, **kwargs):
"""
load data set
Args:
name: name of data set (one of {names})
root: root dir to store downloaded data (default '/tmp')
Returns:
graph data
"""
root = Path(root).expanduser()
data = _dataloaders[name](root, **kwargs)
if restrict_lcc:
data = data.lcc(relabel=True)
return data
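# Usage sketch (illustrative): registering and loading a data set. The name 'my_graph'
# and the loader body are hypothetical; real loaders are registered elsewhere in the
# package via the ``dataloader`` decorator above.
#
#     @dataloader('my_graph')
#     def _load_my_graph(root, **kwargs):
#         ...  # load (or download to `root`) and return the graph data
#
#     data = load_data('my_graph', root='~/data')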
def load_classification_problem(name, graph, root='/tmp', **class_args):
    """
    load classification problem

    Args:
        name: name of the classification data set
        graph: graph data the problem refers to (labels and split indices are remapped
            to the graph's nodes if the graph carries node labels, e.g., after
            restriction to the largest connected component)
        root: root dir to store downloaded data (default '/tmp')
        **class_args: additional keyword arguments passed on to the registered loader

    Returns:
        ClassificationProblem
    """
    root = Path(root).expanduser()
y, split = _classification_loader[name](root=root, **class_args)
if graph.has_node_labels():
index = graph.nodes
index_map = torch.full(y.shape, -1, dtype=torch.long)
index_map[index] = torch.arange(len(index), dtype=torch.long)
y = y[index]
for key, value in split.items():
mapped_index = index_map[value]
split[key] = mapped_index[mapped_index >= 0]
return ClassificationProblem(y, split=split)
load_data.__doc__ = load_data.__doc__.format(names=list(_dataloaders.keys()))
def cluster_string(cluster='metis', num_clusters=10, num_iters: typing.Optional[int] = None, beta=0.1, levels=1):
    """
    build the identifier string for a clustering method and its parameters
    """
    if cluster == 'louvain':
        cl_string = 'louvain'
    elif cluster == 'distributed':
        cl_string = f'distributed_beta{beta}_it{num_iters}'
    elif cluster == 'fennel':
        cl_string = f'fennel_n{num_clusters}_it{num_iters}'
    elif cluster == 'metis':
        cl_string = f'metis_n{num_clusters}'
    else:
        raise RuntimeError(f"Unknown cluster method '{cluster}'.")
    if levels > 1:
        cl_string += f'_hc{levels}'
    return cl_string
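# For example, with the parameter defaults above:
#
#     cluster_string('metis', num_clusters=10)               # -> 'metis_n10'
#     cluster_string('fennel', num_clusters=8, num_iters=5)  # -> 'fennel_n8_it5'
#     cluster_string('metis', num_clusters=10, levels=2)     # -> 'metis_n10_hc2'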
def patch_folder_name(name: str, min_overlap: int, target_overlap: int, cluster='metis',
                      num_clusters=10, num_iters: typing.Optional[int] = None, beta=0.1, levels=1,
                      sparsify='resistance', target_patch_degree=4.0, gamma=0.0):
    """
    build the folder name used to store the patches for a given data set and parameter settings
    """
if sparsify == 'resistance':
sp_string = f"resistance_deg{target_patch_degree}"
elif sparsify == 'rmst':
sp_string = f"rmst_gamma{gamma}"
elif sparsify == 'none':
sp_string = "no_sparsify"
elif sparsify == 'sample':
sp_string = f'sample_deg{target_patch_degree}'
elif sparsify == 'neighbors':
sp_string = f'neighbors_deg{target_patch_degree}'
else:
raise RuntimeError(f"Unknown sparsification method '{sparsify}'.")
cl_string = cluster_string(cluster, num_clusters, num_iters, beta, levels)
return f'{name}_{cl_string}_{sp_string}_mo{min_overlap}_to{target_overlap}_patches'
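# For example, with the default clustering and sparsification settings ('Cora' is only
# an example data set name):
#
#     patch_folder_name('Cora', min_overlap=256, target_overlap=512)
#     # -> 'Cora_metis_n10_resistance_deg4.0_mo256_to512_patches'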
def cluster_file_name(name, cluster='metis', num_clusters=10, num_iters: typing.Optional[int] = None, beta=0.1, levels=1):
    """
    build the file name used to store the clusters for a given data set and parameter settings
    """
cl_string = cluster_string(cluster, num_clusters, num_iters, beta, levels)
return f'{name}_{cl_string}_clusters.pt'
class NoLock:
    """
    no-op implementation of the lock interface, used when locking is disabled
    """
def __enter__(self):
return self
def __exit__(self, exc_type, exc_val, exc_tb):
pass
def acquire(self):
pass
def release(self):
pass
class SyncDict(UserDict):
"""
Class for keeping json-backed dict with locking for sync
"""
def load(self):
"""
restore results from file
Returns:
populated ResultsDict
"""
with self._lock:
with open(self.filename) as f:
self.data = json.load(f)
def save(self):
"""
dump contents to json file
Args:
filename: output file path
"""
with self._lock:
            with atomic_write(self.filename, overwrite=True) as f:  # atomic write avoids losing existing data on failure
json.dump(self.data, f)
def __init__(self, filename, lock=True):
"""
initialise empty ResultsDict
Args:
filename: file to use for storage
lock: set lock=False to avoid locking (use wisely)
"""
super().__init__()
self.filename = Path(filename)
if lock:
self._lock = SoftFileLock(self.filename.with_suffix('.lock'), timeout=10)
else:
self._lock = NoLock() # implements lock interface without doing anything
with self._lock:
if not self.filename.is_file():
self.save()
else:
self.load()
def __enter__(self):
self._lock.acquire()
self.load()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.save()
self._lock.release()
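# Usage sketch (illustrative): as a context manager, a SyncDict reloads the backing file
# on entry and saves on exit, so concurrent writers see a consistent view.
#
#     with SyncDict('state.json') as d:
#         d['last_step'] = 42
#     # 'state.json' now contains {"last_step": 42}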
class ResultsDict(UserDict):
"""
Class for keeping track of results
"""
def load(self):
"""
restore results from file
Args:
filename: input json file
replace: set the replace attribute
Returns:
populated ResultsDict
"""
with self._lock:
with open(self.filename) as f:
self.data = json.load(f)
def save(self):
"""
dump contents to json file
Args:
filename: output file path
"""
with self._lock:
            with atomic_write(self.filename, overwrite=True) as f:  # atomic write avoids losing existing data on failure
json.dump(self.data, f)
def __init__(self, filename, replace=False, lock=True):
"""
initialise empty ResultsDict
Args:
replace: set the replace attribute (default: ``False``)
"""
super().__init__()
self.filename = Path(filename)
if lock:
self._lock = SoftFileLock(self.filename.with_suffix('.lock'), timeout=10)
else:
self._lock = NoLock() # implements lock interface without doing anything
with self._lock:
if not self.filename.is_file():
self.data = {'dims': [], 'runs': []}
self.save()
else:
self.load()
        self.replace = replace  #: if ``True``, updates replace existing data; if ``False``, updates are appended
def __enter__(self):
self._lock.acquire()
self.load()
return self
def __exit__(self, exc_type, exc_val, exc_tb):
self.save()
self._lock.release()
def _update_index(self, index, replace=False, **kwargs):
"""
update data for a given index
Args:
index: integer index into data lists
aucs: new auc value
args: new args data (optional)
"""
for key, val in kwargs.items():
if key not in self:
self.data[key] = [[] for _ in self['dims']]
if replace or self.replace:
self[key][index] = [val]
else:
self[key][index].append(val)
if not (replace or self.replace):
self['runs'][index] += 1
def _insert_index(self, index: int, dim: int, **kwargs):
"""
insert new data at index
Args:
index: integer index into data lists
dim: data dimension for index
aucs: new auc values
args: new args data (optional)
"""
self['dims'].insert(index, dim)
for key in self:
if key not in {'dims', 'runs'}: # these only store single values per dimension
self[key].insert(index, [])
for key, val in kwargs.items():
if key in self:
self[key][index].append(val)
else:
self.data[key] = [[] for _ in self['dims']]
self[key][index].append(val)
self['runs'].insert(index, 1)
def _index(self, dim):
return bisect_left(self['dims'], dim)
def update_dim(self, dim, replace=False, **kwargs):
"""
update data for given dimension
Args:
dim: dimension to update
auc: new auc value
loss: new loss value
args: new args data (optional)
if ``self.contains_dim(dim) == True``, behaviour depends on the value of
``self.replace``
"""
index = self._index(dim)
if index < len(self['dims']) and self['dims'][index] == dim:
self._update_index(index, replace=replace, **kwargs)
else:
self._insert_index(index, dim, **kwargs)
def delete_dim(self, dim):
index = self._index(dim)
if index < len(self['dims']) and self['dims'][index] == dim:
for v in self.values():
del v[index]
def max(self, field, dim=None):
"""
return maximum auc values
Args:
field: field to take maximum over
dim: if ``dim=None``, return list of values for all dimension, else only return maximum value for ``dim``.
"""
if field not in self:
if dim is None:
return [-float('inf') for _ in self['dims']]
else:
return -float('inf')
if dim is None:
return [max(val) for val in self[field]]
else:
index = bisect_left(self['dims'], dim)
if index < len(self['dims']) and self['dims'][index] == dim:
return max(self[field][index])
else:
return -float('inf')
    def min(self, field, dim=None):
        """
        return minimum values for a field (counterpart of ``max``)
        """
if field not in self:
if dim is None:
return [float('inf') for _ in self['dims']]
else:
return float('inf')
if dim is None:
return [min(val) for val in self[field]]
else:
index = bisect_left(self['dims'], dim)
if index < len(self['dims']) and self['dims'][index] == dim:
return min(self[field][index])
else:
return float('inf')
    def get(self, item, dim=None, default=None):
        """
        return the data for ``item`` (restricted to dimension ``dim`` if given), or ``default`` if missing
        """
if item in self:
if dim is None:
return self[item]
elif self.contains_dim(dim):
index = self._index(dim)
return self[item][index]
return default
def contains_dim(self, dim):
"""
equivalent to ``dim in self['dims']``
"""
index = bisect_left(self['dims'], dim)
return index < len(self['dims']) and self['dims'][index] == dim
def reduce_to_dims(self, dims):
"""
remove all data for dimensions not in ``dims``
Args:
dims: list of dimensions to keep
"""
        index = [self._index(d) for d in dims if self.contains_dim(d)]
for key1 in self.data:
if isinstance(self.data[key1], list):
self.data[key1] = [self[key1][i] for i in index]
return self
def runs(self, dim=None):
"""
return the number of runs
Args:
            dim: if ``dim is None``, return a list of the number of runs for all dimensions,
                else return the number of runs for dimension ``dim``.
"""
if dim is None:
return self['runs']
else:
index = bisect_left(self['dims'], dim)
if index < len(self['dims']) and self['dims'][index] == dim:
return self['runs'][index]
else:
return 0
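# Usage sketch (illustrative): recording repeated runs per embedding dimension and
# querying the best value; the field name 'auc' is only an example.
#
#     r = ResultsDict('results.json', lock=False)
#     r.update_dim(32, auc=0.91)
#     r.update_dim(32, auc=0.93)
#     r.max('auc', dim=32)  # -> 0.93
#     r.runs(32)            # -> 2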
class Throttler:
    """
    rate limiter that spaces task submissions at least ``min_interval`` seconds apart
    """
def __init__(self, min_interval=0):
        self.min_interval = min_interval
self.next_run = time.monotonic()
    async def submit_ok(self):
        """
        wait until it is ok to submit the next task
        """
        now = time.monotonic()
        if now > self.next_run + self.min_interval:
            # well past the last scheduled slot, no need to wait
            self.next_run = now
            return
        else:
            # book the next slot and sleep until it is due
            self.next_run += self.min_interval
            await asyncio.sleep(self.next_run - now)
            return
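# Usage sketch (illustrative): spacing out submissions by at least one second.
#
#     async def submit_all():
#         throttler = Throttler(min_interval=1.0)
#         for i in range(3):
#             await throttler.submit_ok()
#             print('submitting task', i)
#
#     asyncio.run(submit_all())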
async def run_script(script_name, _cmd_prefix=None, _task_queue: asyncio.Queue = None, _throttler: Throttler = None,
                     _stderr=False,
                     **kwargs):
    """
    run a script from ``local2global_embedding.run.scripts`` as a subprocess

    Args:
        script_name: name of the script module
        _cmd_prefix: optional command prefix prepended to the python invocation
        _task_queue: optional queue used to limit the number of simultaneous tasks
        _throttler: optional ``Throttler`` used to limit the task creation frequency
        _stderr: if ``True``, redirect all subprocess output to stderr
        **kwargs: passed on to the script as ``--key=value`` command-line options
    """
args = []
if _cmd_prefix is not None:
args.extend(_cmd_prefix.split())
args.extend(['python', '-m', f'local2global_embedding.run.scripts.{script_name}'])
args.extend(f'--{key}={value}' for key, value in kwargs.items())
if _task_queue is not None:
await _task_queue.put(args) # limit number of simultaneous tasks
if _throttler is not None:
await _throttler.submit_ok() # limit task creation frequency
if _stderr:
stdout = sys.stderr # redirect all output to stderr
else:
stdout = None
proc = await asyncio.create_subprocess_exec(*args, stdout=stdout)
await proc.communicate()
if _task_queue is not None:
await _task_queue.get()
_task_queue.task_done()
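# Usage sketch (illustrative): launching several runs of a script with at most two
# simultaneous subprocesses; the script name 'evaluate' is hypothetical.
#
#     async def main():
#         queue = asyncio.Queue(maxsize=2)  # bounds the number of simultaneous tasks
#         await asyncio.gather(*(run_script('evaluate', _task_queue=queue, dim=d)
#                                for d in (32, 64, 128)))
#
#     asyncio.run(main())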
class CSVList:
    """
    argparse type that parses a comma-separated string into a list of ``dtype`` values
    """
def __init__(self, dtype=str):
        self.dtype = dtype
def __call__(self, input: str):
return [self.dtype(s) for s in input.split(',')]
class BooleanString:
    """
    argparse type that maps the strings 'True' and 'False' to the corresponding ``bool``
    """
def __new__(cls, s):
if s not in {'False', 'True'}:
raise ValueError('Not a valid boolean string')
return s == 'True'
class Union:
    """
    argparse type that tries to parse a value as each of ``types`` in turn, returning the first success
    """
def __init__(self, types):
self.types = types
def __call__(self, value: str):
if value == 'None' and type(None) in self.types:
return None
for t in self.types:
try:
return t(value)
except Exception:
pass
raise RuntimeError(f'Cannot parse argument {value}')
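# Usage sketch (illustrative): the three helper classes as argparse argument types.
#
#     parser = argparse.ArgumentParser()
#     parser.add_argument('--dims', type=CSVList(int))          # '1,2,4' -> [1, 2, 4]
#     parser.add_argument('--verbose', type=BooleanString)      # 'True' -> True
#     parser.add_argument('--scale', type=Union([int, float]))  # '2' -> 2, '2.5' -> 2.5
#     args = parser.parse_args(['--dims', '1,2,4', '--verbose', 'True', '--scale', '2.5'])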
class Argument:
"""
Argument wrapper for ScriptParser
"""
def __init__(self, name='', parameter=None, allow_reset=False):
"""
Initialize Argument
Args:
name: argument name
parameter: signature parameter (optional, used to get default value if specified)
"""
self.name = name
        self.allow_reset = allow_reset
self.required = parameter is None or parameter.default is parameter.empty
self.is_set = False
if not self.required:
self._value = parameter.default
else:
self._value = None
def __call__(self, input_str):
"""
parse argument string
Args:
input_str: string to evaluate
Returns: self
Tries to parse `input_str` as python code using `ast.literal_eval`. If this fails, sets value to `input_str`.
"""
if not self.is_set or self.allow_reset:
self.is_set = True
try:
val = literal_eval(input_str)
except Exception: # could not interpret as python literal, assume it is a bare string argument
val = input_str
self._value = val
return self
else:
raise RuntimeError(f"Tried to set value for argument {self.name!r} multiple times, old value: {self._value}, new value: {input_str}")
@property
def value(self):
if not self.required or self.is_set:
return self._value
else:
raise RuntimeError(f"Missing value for required argument {self.name!r}")
def __repr__(self):
repr_str = f'{self.__class__.__name__}(name={self.name!r})'
if self.required:
repr_str += ', required'
if self.is_set:
repr_str += f', value={self.value!r}'
elif not self.required:
repr_str += f', default={self.value!r}'
return repr_str
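# For example, values are parsed as python literals where possible and fall back to bare
# strings otherwise:
#
#     arg = Argument('alpha')
#     arg('0.5')   # literal_eval succeeds -> stores the float 0.5
#     arg.value    # -> 0.5
#
#     arg = Argument('mode')
#     arg('fast')  # not a valid python literal -> stores the string 'fast'
#     arg.value    # -> 'fast'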
class ScriptParser:
"""
Build a command-line interface to a python function
Inspects the function signature to create command-line arguments. It converts
argument `arg` to long option `--arg`. Parsing is similar to python, supporting
mix of positional and named arguments (note its possible to specify named arguments before positional arguments).
Also supports use of `*args` and `**kwargs`.
Help messages are constructed by parsing the doc-string of the wrapped function.
Can be used as a decorator if function should only be used as a script
"""
def __init__(self, func, allow_reset=False):
"""
Wrap `func` as a command-line interface
Args:
func: Callable
"""
self.func = func
self.parser = argparse.ArgumentParser(prog=func.__name__)
self.allow_reset = allow_reset
self.var_pos = False
self.var_keyword = False
self.arguments = []
sig = signature(func).parameters
docstring = parse_doc(func.__doc__)
        help_text = {p.arg_name: p.description for p in docstring.params}
self.parser.description = docstring.short_description
self.parser.add_argument('_pos', nargs='*')
for name, parameter in sig.items():
if parameter.kind == parameter.VAR_POSITIONAL:
if not self.var_pos:
self.var_pos = True
else:
raise RuntimeError('Only expected a single *args')
elif parameter.kind == parameter.VAR_KEYWORD:
if not self.var_keyword:
self.var_keyword = True
else:
raise RuntimeError('Only expected a single **kwargs')
else:
arg = Argument(name, parameter, allow_reset)
self.arguments.append(arg)
if arg.required:
                    self.parser.add_argument(f'--{name}', type=arg, default=arg, help=help_text.get(name, name))
else:
                    help_str = f'{help_text.get(name, name)} (default: {arg.value!r})'
                    self.parser.add_argument(f'--{name}', type=arg, default=arg, help=help_str)
def parse(self, args=None):
if args is None:
args = sys.argv[1:]
if self.var_keyword:
arg_res, unknown = self.parser.parse_known_args(args)
kwargs = vars(arg_res)
unknown_parser = argparse.ArgumentParser()
for arg in unknown:
if arg.startswith("--"):
name = arg[2:].split('=', 1)[0]
new_arg = Argument(name, allow_reset=self.allow_reset)
unknown_parser.add_argument(f'--{name}', type=new_arg, default=new_arg)
self.arguments.append(new_arg)
kwargs.update(vars(unknown_parser.parse_args(unknown)))
else:
kwargs = vars(self.parser.parse_args(args))
args = []
pos_args = kwargs.pop('_pos')
for arg, val in zip(self.arguments, pos_args):
if val.startswith('--'):
raise RuntimeError(f'Unknown keyword argument {val}')
if not arg.is_set:
arg = arg(val)
args.append(arg.value)
kwargs.pop(arg.name)
else:
break
if self.var_pos:
args.extend(Argument()(val).value for val in pos_args[len(args):])
else:
if len(args) != len(pos_args):
raise RuntimeError(f'Too many positional arguments specified. {pos_args=}')
for name, value in kwargs.items():
kwargs[name] = value.value
return args, kwargs
def run(self, args=None):
"""
run the wrapped function with arguments passed on sys.argv or as list of string arguments
"""
args, kwargs = self.parse(args)
self.func(*args, **kwargs)
def __call__(self, args=None):
self.run(args)
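# Usage sketch (illustrative): exposing a function as a command-line script; the
# function 'train' is hypothetical.
#
#     @ScriptParser
#     def train(dim: int = 32, lr: float = 0.01):
#         """train a model
#
#         Args:
#             dim: embedding dimension
#             lr: learning rate
#         """
#         print(f'dim={dim}, lr={lr}')
#
#     train(['--dim=64'])    # -> dim=64, lr=0.01
#     train(['128', '0.1'])  # -> dim=128, lr=0.1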
def watch_progress(tasks):
    """
    print completion status for dask futures and print tracebacks for errored tasks
    """
for c in tasks:
if c.status == 'error':
print(f'{c} errored')
e = c.exception()
print_exception(type(e), e, c.traceback())
else:
print(f'{c} complete')
del c
def get_or_init_client(init_cluster=True):
    """
    return the active dask client or create a new one

    If no client exists and ``init_cluster`` is ``True``, the cluster is set up by running
    ``~/.config/dask/cluster_init.py``, which is expected to define a ``cluster`` variable.
    """
try:
client = get_client()
except ValueError:
if init_cluster:
print('setting up cluster')
cluster_init_path = Path().home() / '.config' / 'dask' / 'cluster_init.py'
kwargs = run_path(cluster_init_path)
client = Client(kwargs['cluster'])
else:
print('launching default client')
client = Client()
return client
def dask_unpack(client, result_future, n):
"""
Unpack dask future of iterable
Args:
client: dask client
result_future: future to unpack
n: number of items to unpack
Returns:
list of unpacked futures
"""
return [client.submit(operator.getitem, result_future, i) for i in range(n)]
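# Usage sketch (illustrative): splitting a future of a tuple into one future per element.
#
#     def min_max(values):
#         return min(values), max(values)
#
#     client = get_or_init_client(init_cluster=False)
#     future = client.submit(min_max, [3, 1, 2])
#     lo, hi = dask_unpack(client, future, 2)
#     print(lo.result(), hi.result())  # -> 1 3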