# Copyright (c) 2021. Lucas G. S. Jeub
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
"""Training run script"""
import os
from filelock import SoftFileLock, FileLock
from local2global_embedding.network import TGraph
from local2global_embedding.run.once_per_worker import once_per_worker
from weakref import finalize
from local2global_embedding.run.once_per_worker.once_per_worker import OncePerWorker
print('importing build-in modules')
import sys
from pathlib import Path
from typing import List
from runpy import run_path
from traceback import print_exception
from collections.abc import Iterable
from copy import copy
print('importing numpy')
import numpy as np
print('importing pytorch')
import torch
print('importing dask')
import dask
import dask.distributed
from dask.distributed import as_completed, Client, get_worker
print('importing log and progress modules')
from tqdm import tqdm
import enlighten
import logging
print('importing needed functions')
from local2global_embedding.run.utils import (ResultsDict, load_data, ScriptParser, patch_folder_name,
cluster_file_name, watch_progress,
load_classification_problem)
from local2global_embedding.utils import speye, set_device
from local2global_embedding.run.scripts import functions as func
from local2global_embedding.run.scripts.utils import build_patch, ScopedTemporaryFile
from functools import partialmethod
from tempfile import NamedTemporaryFile
def with_dependencies(f):
    """Wrap ``f`` so that it accepts, and silently discards, a ``_depends_on`` keyword argument.

    The returned wrapper forwards every other positional and keyword argument to ``f`` unchanged
    and returns ``f``'s result. ``_depends_on`` is consumed by the wrapper and never reaches ``f``.
    NOTE(review): presumably this lets callers make a (dask) task depend on other results without
    passing them to ``f`` itself; confirm against call sites.

    Args:
        f: callable to wrap

    Returns:
        wrapper with the same name, qualname, module and docstring as ``f``
        (``f`` itself is available as ``wrapper.__wrapped__``)
    """
    from functools import wraps

    @wraps(f)  # copies __name__, __qualname__, __doc__, __module__ and sets __wrapped__
    def f_d(*args, _depends_on=None, **kwargs):
        return f(*args, **kwargs)

    return f_d
def run(name='Cora', data_root='/tmp', no_features=False, model='VGAE', num_epochs=10000,
        patience=20, runs=10, cl_runs=5, dims: List[int] = None, hidden_multiplier=2, target_patch_degree=4.0,
        min_overlap: int = None, target_overlap: int = None, gamma=0.0, sparsify='resistance', train_directed=False,
        cluster='metis', num_clusters=10, beta=0.1, num_iters: int = None, lr=0.001, cl_model='logistic',
        cl_train_args={}, cl_model_args={}, dist=False,
        output='.', device: str = None, verbose_train=False, verbose_l2g=False, levels=1, resparsify=0,
        run_baseline=True, normalise=False, restrict_lcc=False, scale=False, rotate=True, translate=True,
        mmap_edges=False, mmap_features=False,
        random_split=True, use_tmp=False, cluster_init=False, use_gpu_frac=1.0, grid_search_params=True,
        progress_bars=True):
    """
    Run training example.

    By default this function writes results to the current working directory. To override this use the ``output``
    argument. Missing output folders are created.

    Args:
        name: Name of data set to load (one of {``'Cora'``, ``'PubMed'``, ``'AMZ_computers'``, ``'AMZ_photo'``})
        data_root: Directory to use for downloaded data
        no_features: If ``True``, discard features and use node identity.
        model: embedding model type (one of {'VGAE', 'GAE', 'DGI'})
        num_epochs: Number of training epochs
        patience: Patience for early stopping
        runs: Number of training runs (keep best result)
        cl_runs: Number of runs passed (as ``runs``) to the classification evaluation (default: 5)
        dims: list of embedding dimensions (default: ``[2]``)
        hidden_multiplier: Hidden dimension is ``hidden_multiplier * dim``
        target_patch_degree: Target patch degree for resistance sparsification.
        min_overlap: Minimum target patch overlap (default: ``max(dims) + 1``)
        target_overlap: Target patch overlap (default: ``2 * max(dims)``)
        gamma: Value of 'gamma' for RMST sparsification
        sparsify: Sparsification method to use (one of {``'resistance'``, ``'none'``, ``'rmst'``})
        train_directed: Use the original directed network (only relevant for some loaders) (default: ``False``)
        cluster: Clustering method to use (one of {``'louvain'``, ``'fennel'`` , ``'distributed'``, ``'metis'``})
        num_clusters: Target number of clusters for distributed, fennel, or metis.
        beta: Parameter for the distributed clustering algorithm
        num_iters: Maximum iterations for distributed or fennel
        lr: Learning rate. Either a single value used for every run or an iterable of per-run learning rates.
            An iterable with exactly two entries and fewer than ``runs`` entries is expanded to ``runs``
            log-spaced values between the two entries.
        cl_model: the classification model to use (one of "logistic" or "mlp") (default: "logistic")
        cl_train_args: extra arguments to pass down to the classification training (default: {})
        cl_model_args: extra arguments to pass to the classification model constructor (default: {})
        dist: If ``True``, use distance decoder instead of inner product decoder
        output: output folder
        device: Device used for training e.g., 'cpu', 'cuda' (defaults to 'cuda' if available else 'cpu')
        verbose_train: If ``True``, show progress info
        verbose_l2g: Verbose output for the alignment step (default: ``False``)
        levels: number of hierarchical patch levels (default: 1)
        resparsify: if > 0, use resistance sparsification for all levels of the hierarchy (default: 0)
        run_baseline: if ``True``, run baseline full model
        normalise: If True, normalise the dataset features (default: ``False``)
        restrict_lcc: If True, restrict the dataset to only consider largest connected component (default: ``False``)
        scale: apply scaling transformations during alignment (default: False)
        rotate: apply rotations during alignment (default: True)
        translate: apply translations during alignment (default: True)
        mmap_edges: use memory mapping for edges (only supported by some loaders) (default: False)
        mmap_features: use memory mapping for features (only supported by some loaders) (default: False)
        random_split: use random train-test splits for evaluation (default: True)
        use_tmp: copy data to tmp dir during load (default: False)
        cluster_init: run the cluster initialisation script (default: False)
        use_gpu_frac: fraction of gpu to use by each worker (default: 1.0)
        grid_search_params: use grid search for classification parameters (only for cl_model='mlp')
        progress_bars: show progress bars (default: True)

    Raises:
        ValueError: if ``lr`` is an iterable with fewer than ``runs`` entries and not exactly two entries.

    This function is also exposed as a command-line interface.

    .. rubric:: References

    .. [#l2g] L. G. S. Jeub et al.
            “Local2Global: Scaling global representation learning on graphs via local training”.
            DLG-KDD’21. 2021. `arXiv:2107.12224 [cs.LG] <https://arxiv.org/abs/2107.12224>`_.
    """
    if not progress_bars:
        # NOTE: this patches the tqdm class itself, so it disables tqdm bars process-wide.
        tqdm.__init__ = partialmethod(tqdm.__init__, disable=True)

    grid_search_params = grid_search_params and cl_model == 'mlp'  # grid search only implemented for MLP
    if grid_search_params:
        eval_func = func.mlp_grid_search_eval
    else:
        eval_func = func.evaluate

    if verbose_train:
        logging.basicConfig(level=logging.DEBUG)

    if cluster_init:
        print('setting up cluster')
        cluster_init_path = Path.home() / '.config' / 'dask' / 'cluster_init.py'
        # The init script must define a module-level ``cluster`` object.
        cluster_namespace = run_path(cluster_init_path)
        client = Client(cluster_namespace['cluster'])
    else:
        print('launching default client')
        client = Client()
    print(client.dashboard_link)

    # Only request GPU resources if the workers are configured to advertise a 'gpu' resource.
    if 'gpu' in dask.config.get('distributed.worker.resources'):
        gpu_req = {'gpu': use_gpu_frac}
    else:
        gpu_req = {}

    if dims is None:
        dims = [2]
    output_folder = Path(output).expanduser()
    data_root = Path(data_root).expanduser()
    result_folder_name = f'{num_epochs=}_{patience=}_{lr=}'
    print(f'Started experiment for data set {name}')
    print(f'Results will be placed in {output_folder.resolve()}')
    print(f'Data root is {data_root.resolve()}')
    if use_tmp:
        print('Any memmapped data will be moved to local storage.')
    if normalise:
        print('features will be normalised before training')

    manager = enlighten.get_manager(threaded=True)
    baseline_progress = manager.counter(desc='baseline', total=0, file=sys.stdout)
    patch_progress = manager.counter(desc='patch', total=0, file=sys.stdout)
    align_progress = manager.counter(desc='align', total=0, file=sys.stdout)
    eval_progress = manager.counter(desc='eval', total=0, file=sys.stdout)
    total_progress = manager.counter(desc='total', total=0, file=sys.stdout)

    def progress_callback(bar):
        """Register one more expected task on ``bar`` and return a future callback that ticks it."""
        bar.total += 1
        total_progress.total += 1

        def callback(future):
            bar.update()
            total_progress.update()

        return callback

    if train_directed:
        train_basename = f'{name}_dir_{model}'
        eval_basename = f'{name}_dir_{model}'
    else:
        train_basename = f'{name}_{model}'
        eval_basename = f'{name}_{model}'

    min_overlap = min_overlap if min_overlap is not None else max(dims) + 1
    target_overlap = target_overlap if target_overlap is not None else 2 * max(dims)

    if dist:
        eval_basename += '_dist'
        if model != 'DGI':
            train_basename += '_dist'
    if normalise:
        eval_basename += '_norm'
        train_basename += '_norm'
    if grid_search_params:
        eval_basename += '_gridsearch'
    eval_basename += f'_{cl_model}'
    if cl_model_args:
        eval_basename += f'({cl_model_args})'
    if cl_train_args:
        eval_basename += f'{cl_train_args}'

    l2g_name = 'l2g'
    if scale:
        l2g_name += '_scale'
    if not rotate:
        l2g_name += "_norotate"
    if not translate:
        l2g_name += "_notranslate"
    if levels > 1:
        l2g_name += f'_hc{levels}'

    # Normalise ``lr`` to one learning rate per run.
    if isinstance(lr, Iterable):
        lr = list(lr)
        if len(lr) < runs:
            if len(lr) == 2:
                lr = np.logspace(np.log10(lr[0]), np.log10(lr[1]), runs)
            else:
                raise ValueError(f'Number of learning rates {len(lr)} specified does not match number of runs {runs}.')
    else:
        lr = [lr for _ in range(runs)]

    all_tasks = as_completed()

    def build_training_data(data, device):
        """Move ``data`` to ``device`` as a TGraph and prepare float32 (optionally row-normalised) features."""
        device = set_device(device)
        data = data.to(TGraph).to(device=device)
        if no_features:
            data.x = speye(data.num_nodes).to(device)
        else:
            data.x = torch.as_tensor(data.x, dtype=torch.float32)
            if normalise:
                r_sum = data.x.sum(dim=1)
                r_sum[r_sum == 0] = 1.0  # avoid division by zero
                data.x /= r_sum[:, None]
        return data

    @dask.delayed(pure=True, traverse=False)
    def load_patch(patch_folder, data, i):
        """Build (once per worker) the training data for the subgraph of patch ``i``."""
        data = data._get_value()
        nodes = np.load(patch_folder / f'patch{i}_index.npy')
        return OncePerWorker.instance_for_function(
            lambda: build_training_data(data.subgraph(nodes, relabel=False), device))

    def load_patch_data(patch_folder, data, n_patches):
        return [load_patch(patch_folder, data, i) for i in range(n_patches)]

    def load_and_copy_data():
        """Load the data set (optionally copying memmapped arrays to TMPDIR) and attach its classification data."""
        data = load_data(name=name, root=data_root, restrict_lcc=restrict_lcc,
                         mmap_edges=mmap_edges, mmap_features=mmap_features, directed=train_directed)
        if use_tmp:
            tmpdir = Path(os.getenv("TMPDIR", "/tmp"))
            e_file = tmpdir / f"{name}_edges.npy"
            x_file = tmpdir / f"{name}_x.npy"
            if isinstance(data.edge_index, np.memmap):
                # File lock so only one process writes the local copy.
                with FileLock(tmpdir / f"{name}_edges.lock"):
                    if not e_file.is_file():
                        np.save(e_file, data.edge_index)
                data.edge_index = np.load(e_file, mmap_mode='r')
            if isinstance(data.x, np.memmap):
                with FileLock(tmpdir / f"{name}_x.lock"):
                    if not x_file.is_file():
                        np.save(x_file, data.x)
                data.x = np.load(x_file, mmap_mode='r')
        cl_data = load_classification_problem(name, data, root=data_root)
        data.cl_data = cl_data
        return data

    # Lazily created handles, shared by the baseline and patch sections below.
    data = None
    baseline_train_data = None

    def _load_data():
        nonlocal data
        if data is None:
            data = once_per_worker(load_and_copy_data)

    @dask.delayed(pure=True, traverse=False)
    def build_baseline_training_data(data):
        data = copy(data._get_value())
        return OncePerWorker.instance_for_function(lambda: build_training_data(data, device))

    def _baseline_data():
        nonlocal baseline_train_data
        _load_data()
        if baseline_train_data is None:
            baseline_train_data = build_baseline_training_data(data)

    if run_baseline:
        # compute baseline full model if necessary
        result_folder = output_folder / result_folder_name
        # parents=True: ``output`` may not exist yet (the patch branch below does the same).
        result_folder.mkdir(exist_ok=True, parents=True)
        baseline_eval_file = result_folder / f'{eval_basename}_full_eval.json'
        for d in dims:
            with ResultsDict(baseline_eval_file, lock=False) as baseline_data:
                r_done = baseline_data.runs(d)
            # Only schedule the runs that do not have results yet.
            for r in range(r_done, runs):
                _load_data()
                _baseline_data()
                baseline_info_file = result_folder / f'{train_basename}_d{d}_r{r}_full_info.json'
                coords_task = dask.delayed(func.train, pure=False)(
                    data=baseline_train_data, model=model,
                    lr=lr[r], num_epochs=num_epochs,
                    patience=patience, verbose=verbose_train,
                    results_file=baseline_info_file, dim=d,
                    hidden_multiplier=hidden_multiplier,
                    dist=dist,
                    save_coords=mmap_features)
                eval_task = client.compute(dask.delayed(eval_func, pure=False)(
                    model=cl_model,
                    graph=data,
                    embedding=coords_task.coordinates,
                    results_file=baseline_eval_file,
                    dist=dist,
                    device=device,
                    train_args=cl_train_args,
                    model_args=cl_model_args,
                    runs=cl_runs,
                    random_split=random_split,
                    mmap_features=mmap_features,
                    use_tmp=use_tmp), resources=gpu_req)
                eval_task.add_done_callback(progress_callback(baseline_progress))
                all_tasks.add(eval_task)
                del eval_task
                del coords_task

    patch_folder = output_folder / patch_folder_name(name, min_overlap, target_overlap, cluster, num_clusters,
                                                     num_iters, beta, levels, sparsify, target_patch_degree,
                                                     gamma)
    cluster_file = output_folder / cluster_file_name(name, cluster, num_clusters, num_iters, beta, levels)
    result_folder = patch_folder / result_folder_name
    result_folder.mkdir(exist_ok=True, parents=True)

    # Lock so that only one process prepares the patches for this configuration.
    with SoftFileLock(patch_folder.with_suffix('.lock')):
        pg_exists = (patch_folder / 'patch_graph.pt').is_file()
        if not pg_exists:
            _load_data()
            patch_graph = dask.delayed(func.prepare_patches, pure=False)(
                output_folder=output_folder, name=name, graph=data,
                min_overlap=min_overlap, target_overlap=target_overlap,
                cluster=cluster,
                num_clusters=num_clusters, num_iters=num_iters, beta=beta, levels=levels,
                sparsify=sparsify, target_patch_degree=target_patch_degree,
                gamma=gamma,
                verbose=False).persist()
            patch_graph_initialised = True
            num_patches = patch_graph.num_nodes.compute()
        else:
            num_patches = torch.load(patch_folder / 'patch_graph.pt').num_nodes
            patch_graph = dask.delayed(torch.load, pure=True)(patch_folder / 'patch_graph.pt')
            patch_graph_initialised = False

    l2g_eval_file = result_folder / f'{eval_basename}_{l2g_name}_eval.json'
    n_nodes = None
    patch_data = None
    for d in dims:
        with ResultsDict(l2g_eval_file, lock=False) as res:
            r_done = res.runs(d)
        for r in range(r_done, runs):
            _load_data()
            if not patch_graph_initialised:
                # Persist the lazily loaded patch graph once, on first use.
                patch_graph = patch_graph.persist()
                patch_graph_initialised = True
            if patch_data is None:
                patch_data = load_patch_data(patch_folder, data, num_patches)

            patches = []
            for pi in range(num_patches):
                patch_result_file = result_folder / f'{train_basename}_patch{pi}_d{d}_r{r}_info.json'
                patch = dask.delayed(func.train, nout=2)(data=patch_data[pi], model=model,
                                                         lr=lr[r], num_epochs=num_epochs,
                                                         patience=patience, verbose=verbose_train,
                                                         results_file=patch_result_file, dim=d,
                                                         hidden_multiplier=hidden_multiplier, dist=dist,
                                                         save_coords=mmap_features)
                patches.append(patch)
                del patch

            l2g_coords_file = result_folder / f'{train_basename}_d{d}_r{r}_{l2g_name}_coords.npy'
            if l2g_coords_file.is_file():
                # Reuse previously aligned coordinates.
                if mmap_features:
                    l2g_coords = l2g_coords_file
                else:
                    l2g_coords = dask.delayed(np.load)(file=l2g_coords_file)
            else:
                if n_nodes is None:
                    n_nodes = data.num_nodes.compute()
                shape = (n_nodes, d)
                l2g_coords = func.hierarchical_l2g_align_patches(
                    patch_graph=patch_graph,
                    shape=shape,
                    scale=scale,
                    rotate=rotate,
                    translate=translate,
                    patches=patches,
                    mmap=mmap_features,
                    use_tmp=use_tmp,
                    verbose=verbose_l2g,
                    output_file=l2g_coords_file,
                    cluster_file=cluster_file,
                    resparsify=resparsify
                )

            coords_task = dask.delayed(eval_func, pure=False)(
                model=cl_model,
                graph=data,
                embedding=l2g_coords,
                results_file=l2g_eval_file,
                dist=dist,
                device=device,
                train_args=cl_train_args,
                model_args=cl_model_args,
                runs=cl_runs,
                random_split=random_split,
                mmap_features=mmap_features,
                use_tmp=use_tmp
            )
            coords_task = client.compute(coords_task, resources=gpu_req)
            coords_task.add_done_callback(progress_callback(eval_progress))
            all_tasks.add(coords_task)
            del coords_task
            del l2g_coords

    baseline_progress.refresh()
    patch_progress.refresh()
    align_progress.refresh()
    eval_progress.refresh()
    total_progress.refresh()
    # make sure to wait for all tasks to complete and report overall progress
    watch_progress(all_tasks)
    manager.stop()
if __name__ == '__main__':
    # Command-line entry point: expose ``run`` through the project's ScriptParser.
    print('launching main training script')
    cli = ScriptParser(run, True)
    cli.run()