#  Copyright (c) 2021. Lucas G. S. Jeub
#
#  Permission is hereby granted, free of charge, to any person obtaining a copy
#  of this software and associated documentation files (the "Software"), to deal
#  in the Software without restriction, including without limitation the rights
#  to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
#  copies of the Software, and to permit persons to whom the Software is
#  furnished to do so, subject to the following conditions:
#
#  The above copyright notice and this permission notice shall be included in all
#  copies or substantial portions of the Software.
#
#  THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
#  IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
#  FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
#  AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
#  LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
#  OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
#  SOFTWARE.
"""Training run script"""
import os
from filelock import SoftFileLock, FileLock
from local2global_embedding.network import TGraph
from local2global_embedding.run.once_per_worker import once_per_worker
from weakref import finalize
from local2global_embedding.run.once_per_worker.once_per_worker import OncePerWorker
print('importing built-in modules')
import sys
from pathlib import Path
from typing import List
from runpy import run_path
from traceback import print_exception
from collections.abc import Iterable
from copy import copy
print('importing numpy')
import numpy as np
print('importing pytorch')
import torch
print('importing dask')
import dask
import dask.distributed
from dask.distributed import as_completed, Client, get_worker
print('importing log and progress modules')
from tqdm import tqdm
import enlighten
import logging
print('importing needed functions')
from local2global_embedding.run.utils import (ResultsDict, load_data, ScriptParser, patch_folder_name,
                                              cluster_file_name, watch_progress,
                                              load_classification_problem)
from local2global_embedding.utils import speye, set_device
from local2global_embedding.run.scripts import functions as func
from local2global_embedding.run.scripts.utils import build_patch, ScopedTemporaryFile
from functools import partialmethod
from tempfile import NamedTemporaryFile
def with_dependencies(f):
    """Wrap ``f`` to accept and ignore a ``_depends_on`` keyword argument.

    This makes it possible to pass task results to ``f`` purely to express
    scheduling dependencies without forwarding them to the wrapped function.
    """
    def f_d(*args, _depends_on=None, **kwargs):
        return f(*args, **kwargs)
    f_d.__name__ = f.__name__
    return f_d
def run(name='Cora', data_root='/tmp', no_features=False, model='VGAE', num_epochs=10000,
        patience=20, runs=10, cl_runs=5, dims: List[int] = None, hidden_multiplier=2, target_patch_degree=4.0,
        min_overlap: int = None, target_overlap: int = None, gamma=0.0, sparsify='resistance', train_directed=False,
        cluster='metis', num_clusters=10, beta=0.1, num_iters: int = None, lr=0.001, cl_model='logistic',
        cl_train_args={}, cl_model_args={}, dist=False,
        output='.', device: str = None, verbose_train=False, verbose_l2g=False, levels=1, resparsify=0,
        run_baseline=True, normalise=False, restrict_lcc=False, scale=False, rotate=True, translate=True,
        mmap_edges=False, mmap_features=False,
        random_split=True, use_tmp=False, cluster_init=False, use_gpu_frac=1.0, grid_search_params=True,
        progress_bars=True):
    """
    Run training example.

    By default this function writes results to the current working directory. To override this use the ``output``
    argument.

    Args:
        name: Name of data set to load (one of {``'Cora'``, ``'PubMed'``, ``'AMZ_computers'``, ``'AMZ_photo'``})
        data_root: Directory to use for downloaded data
        no_features: If ``True``, discard features and use node identity.
        model: embedding model type (one of {'VGAE', 'GAE', 'DGI'})
        num_epochs: Number of training epochs
        patience: Patience for early stopping
        runs: Number of training runs (keep best result)
        cl_runs: Number of classification training runs for evaluation
        dims: list of embedding dimensions (default: ``[2]``)
        hidden_multiplier: Hidden dimension is ``hidden_multiplier * dim``
        target_patch_degree: Target patch degree for resistance sparsification.
        min_overlap: Minimum target patch overlap (default: ``max(dims) + 1``)
        target_overlap: Target patch overlap (default: ``2 * max(dims)``)
        gamma: Value of 'gamma' for RMST sparsification
        sparsify: Sparsification method to use (one of {``'resistance'``, ``'none'``, ``'rmst'``})
        train_directed: Use the original directed network (only relevant for some loaders) (default: ``False``)
        cluster: Clustering method to use (one of {``'louvain'``, ``'fennel'`` , ``'distributed'``, ``'metis'``})
        num_clusters: Target number of clusters for distributed, fennel, or metis.
        beta: Parameter for the distributed clustering algorithm
        num_iters: Maximum iterations for distributed or fennel
        lr: Learning rate
        cl_model: the classification model to use (one of "logistic" or "mlp") (default: "logistic")
        cl_train_args: extra arguments to pass down to the classification training (default: {})
        cl_model_args: extra arguments to pass to the classification model constructor (default: {})
        dist: If ``True``, use distance decoder instead of inner product decoder
        verbose_l2g: Verbose output for the alignment step (default: ``False``)
        normalise: If ``True``, normalise the dataset features (default: ``False``)
        restrict_lcc: If ``True``, restrict the dataset to its largest connected component (default: ``False``)
        output: output folder
        levels: number of hierarchical patch levels (default: 1)
        resparsify: if > 0, use resistance sparsification for all levels of the hierarchy (default: 0)
        scale: apply scaling transformations during alignment (default: False)
        rotate: apply rotations during alignment (default: True)
        translate: apply translations during alignment (default: True)
        mmap_edges: use memory mapping for edges (only supported by some loaders) (default: False)
        mmap_features: use memory mapping for features (only supported by some loaders) (default: False)
        random_split: use random train-test splits for evaluation (default: True)
        use_tmp: copy data to tmp dir during load (default: False)
        cluster_init: run the cluster initialisation script (default: False)
        use_gpu_frac: fraction of gpu to use by each worker (default: 1.0)
        grid_search_params: use grid search for classification parameters (only for cl_model='mlp')
        device: Device used for training e.g., 'cpu', 'cuda' (defaults to 'cuda' if available else 'cpu')
        verbose_train: If ``True``, show progress info
        run_baseline: if ``True``, run baseline full model
        progress_bars: show progress bars (default: True)
    This function is also exposed as a command-line interface.
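
    For example, a minimal direct call from Python (a sketch only; any of the
    keyword arguments documented above may be supplied)::

        run(name='Cora', data_root='/tmp', dims=[2], runs=1, output='results')
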
    .. rubric:: References
    .. [#l2g] L. G. S. Jeub et al.
              “Local2Global: Scaling global representation learning on graphs via local training”.
              DLG-KDD’21. 2021. `arXiv:2107.12224 [cs.LG] <https://arxiv.org/abs/2107.12224>`_.
    """
    if not progress_bars:
        tqdm.__init__ = partialmethod(tqdm.__init__, disable=True)
    grid_search_params = grid_search_params and cl_model == 'mlp'  # grid search only implemented for MLP
    if grid_search_params:
        eval_func = func.mlp_grid_search_eval
    else:
        eval_func = func.evaluate
    if verbose_train:
        logging.basicConfig(level=logging.DEBUG)
    if cluster_init:
        print('setting up cluster')
        cluster_init_path = Path.home() / '.config' / 'dask' / 'cluster_init.py'
        # execute the user-provided setup script and use the ``cluster`` object it defines
        init_globals = run_path(str(cluster_init_path))
        client = Client(init_globals['cluster'])
    else:
        print('launching default client')
        client = Client()
    print(client.dashboard_link)
    if 'gpu' in dask.config.get('distributed.worker.resources'):
        gpu_req = {'gpu': use_gpu_frac}
    else:
        gpu_req = {}
    if dims is None:
        dims = [2]
    output_folder = Path(output).expanduser()
    data_root = Path(data_root).expanduser()
    result_folder_name = f'{num_epochs=}_{patience=}_{lr=}'
    print(f'Started experiment for data set {name}')
    print(f'Results will be placed in {output_folder.resolve()}')
    print(f'Data root is {data_root.resolve()}')
    if use_tmp:
        print('Any memmapped data will be moved to local storage.')
    if normalise:
        print('features will be normalised before training')
    manager = enlighten.get_manager(threaded=True)
    baseline_progress = manager.counter(desc='baseline', total=0, file=sys.stdout)
    patch_progress = manager.counter(desc='patch', total=0, file=sys.stdout)
    align_progress = manager.counter(desc='align', total=0, file=sys.stdout)
    eval_progress = manager.counter(desc='eval', total=0, file=sys.stdout)
    total_progress = manager.counter(desc='total', total=0, file=sys.stdout)
    def progress_callback(bar):
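        # account for one more expected task on this bar and the overall counter;
        # the returned callback ticks both when the future completes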
        bar.total += 1
        total_progress.total += 1
        def callback(future):
            bar.update()
            total_progress.update()
        return callback
    if train_directed:
        train_basename = f'{name}_dir_{model}'
        eval_basename = f'{name}_dir_{model}'
    else:
        train_basename = f'{name}_{model}'
        eval_basename = f'{name}_{model}'
    min_overlap = min_overlap if min_overlap is not None else max(dims) + 1
    target_overlap = target_overlap if target_overlap is not None else 2 * max(dims)
    if dist:
        eval_basename += '_dist'
        if model != 'DGI':
            train_basename += '_dist'
    if normalise:
        eval_basename += '_norm'
        train_basename += '_norm'
    if grid_search_params:
        eval_basename += '_gridsearch'
    eval_basename += f'_{cl_model}'
    if cl_model_args:
        eval_basename += f'({cl_model_args})'
    if cl_train_args:
        eval_basename += f'{cl_train_args}'
    l2g_name = 'l2g'
    if scale:
        l2g_name += '_scale'
    if not rotate:
        l2g_name += "_norotate"
    if not translate:
        l2g_name += "_notranslate"
    if levels > 1:
        l2g_name += f'_hc{levels}'
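    # normalise the learning-rate argument: a scalar is repeated for every run,
    # two values are expanded into a log-spaced range over all runs, and a full
    # list supplies one learning rate per run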
    if isinstance(lr, Iterable):
        lr = list(lr)
        if len(lr) < runs:
            if len(lr) == 2:
                lr = np.logspace(np.log10(lr[0]), np.log10(lr[1]), runs)
            else:
                raise ValueError(f'Got {len(lr)} learning rates for {runs} runs; specify a single value, '
                                 f'two values (expanded to a log-spaced range), or one value per run.')
    else:
        lr = [lr for _ in range(runs)]
    all_tasks = as_completed()
    def build_training_data(data, device):
        device = set_device(device)
        data = data.to(TGraph).to(device=device)
        if no_features:
            data.x = speye(data.num_nodes).to(device)
        else:
            data.x = torch.as_tensor(data.x, dtype=torch.float32)
            if normalise:
                r_sum = data.x.sum(dim=1)
                r_sum[r_sum == 0] = 1.0  # avoid division by zero
                data.x /= r_sum[:, None]
        return data
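    # each patch subgraph is constructed lazily and cached once per dask worker,
    # so repeated training runs on the same worker reuse the prepared data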
    @dask.delayed(pure=True, traverse=False)
    def load_patch(patch_folder, data, i):
        data = data._get_value()
        nodes = np.load(patch_folder / f'patch{i}_index.npy')
        return OncePerWorker.instance_for_function(lambda: build_training_data(data.subgraph(nodes, relabel=False), device))
    def load_patch_data(patch_folder, data, n_patches):
        return [load_patch(patch_folder, data, i) for i in range(n_patches)]
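    # load the full data set; when ``use_tmp`` is set, memory-mapped arrays are
    # copied to local temporary storage (file locks ensure the copy happens only once)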
    def load_and_copy_data():
        data = load_data(name=name, root=data_root, restrict_lcc=restrict_lcc,
                         mmap_edges=mmap_edges, mmap_features=mmap_features, directed=train_directed)
        if use_tmp:
            tmpdir = Path(os.getenv("TMPDIR", "/tmp"))
            e_file = tmpdir / f"{name}_edges.npy"
            x_file = tmpdir / f"{name}_x.npy"
            if isinstance(data.edge_index, np.memmap):
                with FileLock(tmpdir / f"{name}_edges.lock"):
                    if not e_file.is_file():
                        np.save(e_file, data.edge_index)
                    data.edge_index = np.load(e_file, mmap_mode='r')
            if isinstance(data.x, np.memmap):
                with FileLock(tmpdir / f"{name}_x.lock"):
                    if not x_file.is_file():
                        np.save(x_file, data.x)
                    data.x = np.load(x_file, mmap_mode='r')
        cl_data = load_classification_problem(name, data, root=data_root)
        data.cl_data = cl_data
        return data
    data = None
    baseline_train_data = None
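    # defer data loading: the data set is loaded at most once per dask worker
    # and only when a task actually needs it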
    def _load_data():
        nonlocal data
        if data is None:
            data = once_per_worker(load_and_copy_data)
    @dask.delayed(pure=True, traverse=False)
    def build_baseline_training_data(data):
        data = copy(data._get_value())
        return OncePerWorker.instance_for_function(lambda: build_training_data(data, device))
    def _baseline_data():
        nonlocal baseline_train_data
        _load_data()
        if baseline_train_data is None:
            baseline_train_data = build_baseline_training_data(data)
    if run_baseline:
        # compute baseline full model if necessary
        result_folder = output_folder / result_folder_name
        result_folder.mkdir(parents=True, exist_ok=True)
        baseline_eval_file = result_folder / f'{eval_basename}_full_eval.json'
        for d in dims:
            with ResultsDict(baseline_eval_file, lock=False) as baseline_data:
                r_done = baseline_data.runs(d)
            coords = []
            for r in range(r_done, runs):
                _load_data()
                _baseline_data()
                baseline_info_file = result_folder / f'{train_basename}_d{d}_r{r}_full_info.json'
                coords_task = dask.delayed(func.train, pure=False)(
                    data=baseline_train_data, model=model,
                    lr=lr[r], num_epochs=num_epochs,
                    patience=patience, verbose=verbose_train,
                    results_file=baseline_info_file, dim=d,
                    hidden_multiplier=hidden_multiplier,
                    dist=dist,
                    save_coords=mmap_features)
                coords.append(coords_task)
                eval_task = client.compute(dask.delayed(eval_func, pure=False)(
                    model=cl_model,
                    graph=data,
                    embedding=coords_task.coordinates,
                    results_file=baseline_eval_file,
                    dist=dist,
                    device=device,
                    train_args=cl_train_args,
                    model_args=cl_model_args,
                    runs=cl_runs,
                    random_split=random_split,
                    mmap_features=mmap_features,
                    use_tmp=use_tmp), resources=gpu_req)
                eval_task.add_done_callback(progress_callback(baseline_progress))
                all_tasks.add(eval_task)
                del eval_task
                del coords_task
    patch_folder = output_folder / patch_folder_name(name, min_overlap, target_overlap, cluster, num_clusters,
                                                     num_iters, beta, levels, sparsify, target_patch_degree,
                                                     gamma)
    cluster_file = output_folder / cluster_file_name(name, cluster, num_clusters, num_iters, beta, levels)
    result_folder = patch_folder / result_folder_name
    result_folder.mkdir(exist_ok=True, parents=True)
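    # check (under a lock to guard against concurrent runs) whether the patch
    # decomposition already exists on disk; prepare it only if it does not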
    with SoftFileLock(patch_folder.with_suffix('.lock')):
        pg_exists = (patch_folder / 'patch_graph.pt').is_file()
    if not pg_exists:
        _load_data()
        patch_graph = dask.delayed(func.prepare_patches, pure=False)(
            output_folder=output_folder, name=name, graph=data,
            min_overlap=min_overlap, target_overlap=target_overlap,
            cluster=cluster,
            num_clusters=num_clusters, num_iters=num_iters, beta=beta, levels=levels,
            sparsify=sparsify, target_patch_degree=target_patch_degree,
            gamma=gamma,
            verbose=False).persist()
        patch_graph_initialised = True
        num_patches = patch_graph.num_nodes.compute()
    else:
        num_patches = torch.load(patch_folder / 'patch_graph.pt').num_nodes
        patch_graph = dask.delayed(torch.load, pure=True)(patch_folder / 'patch_graph.pt')
        patch_graph_initialised = False
    l2g_eval_file = result_folder / f'{eval_basename}_{l2g_name}_eval.json'
    n_nodes = None
    patch_data = None
    for d in dims:
        with ResultsDict(l2g_eval_file, lock=False) as res:
            r_done = res.runs(d)
        for r in range(r_done, runs):
            _load_data()
            if not patch_graph_initialised:
                patch_graph = patch_graph.persist()
            if patch_data is None:
                patch_data = load_patch_data(patch_folder, data, num_patches)
            patches = []
            for pi in range(num_patches):
                patch_result_file = result_folder / f'{train_basename}_patch{pi}_d{d}_r{r}_info.json'
                patch = dask.delayed(func.train, nout=2)(
                    data=patch_data[pi], model=model,
                    lr=lr[r], num_epochs=num_epochs,
                    patience=patience, verbose=verbose_train,
                    results_file=patch_result_file, dim=d,
                    hidden_multiplier=hidden_multiplier, dist=dist,
                    save_coords=mmap_features)
                patches.append(patch)
                del patch
            l2g_coords_file = result_folder / f'{train_basename}_d{d}_r{r}_{l2g_name}_coords.npy'
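            # reuse aligned coordinates from disk when available; otherwise run
            # the hierarchical local2global alignment over the trained patches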
            if l2g_coords_file.is_file():
                if mmap_features:
                    l2g_coords = l2g_coords_file
                else:
                    l2g_coords = dask.delayed(np.load)(file=l2g_coords_file)
            else:
                if n_nodes is None:
                    n_nodes = data.num_nodes.compute()
                shape = (n_nodes, d)
                l2g_coords = func.hierarchical_l2g_align_patches(
                                         patch_graph=patch_graph,
                                         shape=shape,
                                         scale=scale,
                                         rotate=rotate,
                                         translate=translate,
                                         patches=patches,
                                         mmap=mmap_features,
                                         use_tmp=use_tmp,
                                         verbose=verbose_l2g,
                                         output_file=l2g_coords_file,
                                         cluster_file=cluster_file,
                                         resparsify=resparsify
                                        )
            coords_task = dask.delayed(eval_func, pure=False)(
                model=cl_model,
                graph=data,
                embedding=l2g_coords,
                results_file=l2g_eval_file,
                dist=dist,
                device=device,
                train_args=cl_train_args,
                model_args=cl_model_args,
                runs=cl_runs,
                random_split=random_split,
                mmap_features=mmap_features,
                use_tmp=use_tmp
            )
            coords_task = client.compute(coords_task, resources=gpu_req)
            coords_task.add_done_callback(progress_callback(eval_progress))
            all_tasks.add(coords_task)
            del coords_task
            del l2g_coords
    baseline_progress.refresh()
    patch_progress.refresh()
    align_progress.refresh()
    eval_progress.refresh()
    total_progress.refresh()
    # make sure to wait for all tasks to complete and report overall progress
    watch_progress(all_tasks)
    manager.stop() 
if __name__ == '__main__':
    print('launching main training script')
    # run main script
    ScriptParser(run, True).run()