#  Copyright (c) 2021. Lucas G. S. Jeub
#
#  Permission is hereby granted, free of charge, to any person obtaining a copy
#  of this software and associated documentation files (the "Software"), to deal
#  in the Software without restriction, including without limitation the rights
#  to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
#  copies of the Software, and to permit persons to whom the Software is
#  furnished to do so, subject to the following conditions:
#
#  The above copyright notice and this permission notice shall be included in all
#  copies or substantial portions of the Software.
#
#  THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
#  IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
#  FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
#  AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
#  LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
#  OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
#  SOFTWARE.
from pathlib import Path
from statistics import mean, stdev
import torch
from copy import copy
import numpy as np
from numpy.lib.format import open_memmap
from typing import Optional
from local2global_embedding.embedding.eval import reconstruction_auc
from local2global_embedding.classfication import Logistic, train, accuracy, MLP
from local2global_embedding.run.utils import ResultsDict, ScriptParser, load_classification_problem, load_data
from .utils import ScopedTemporaryFile
from traceback import print_exc
def evaluate(graph, embedding, results_file: str, dist=False,
             device: Optional[str]=None, runs=50, train_args={}, mmap_features=False, random_split=False,
             model='logistic', model_args={}, use_tmp=False):
    """Evaluate an embedding via reconstruction AUC and repeated classification runs.

    The embedding coordinates are attached as features ``x`` to a shallow copy
    of ``graph.cl_data``, the reconstruction AUC is computed once, and a fresh
    classifier is trained and scored ``runs`` times. The AUC, the list of
    accuracies and the settings used are stored through
    ``ResultsDict.update_dim`` under the embedding dimension.

    Args:
        graph: object exposing a ``cl_data`` attribute (the classification
            problem); it is also passed to ``reconstruction_auc``.
        embedding: path (``str`` or ``Path``) to a ``.npy`` file that is read
            with ``np.load``, or any array-like accepted by ``np.asarray``.
            Must be two-dimensional; the second axis length is used as the
            embedding dimension.
        results_file: file handed to ``ResultsDict`` (opened with
            ``replace=False, lock=True``).
        dist: forwarded to ``reconstruction_auc``.
        device: forwarded to ``train``.
        runs: number of independent train/score repetitions.
        train_args: overrides for the default training arguments
            (``num_epochs``, ``patience``, ``lr``, ``batch_size``, ``alpha``,
            ``beta``, ``weight_decay``); ``None`` means no overrides.
        mmap_features: if true, a file-based embedding is loaded with
            ``mmap_mode='r'`` instead of being read fully by ``np.load``.
        random_split: if true, ``cl_data.resplit()`` is called before each run.
        model: ``'logistic'`` or ``'mlp'``.
        model_args: extra keyword arguments for the model constructor;
            ``None`` means none. For ``'mlp'`` without a ``'hidden_dim'``
            entry, the embedding dimension is passed as the second positional
            argument of ``MLP`` (presumably the hidden dimension -- confirm
            against ``MLP``).
        use_tmp: if true together with ``mmap_features``, the coordinates are
            first copied into a temporary memory-mapped ``.npy`` file.

    Raises:
        ValueError: if ``model`` is not a supported model type. This is
            checked before any data is loaded.
        Exception: any error raised during evaluation has its traceback
            printed and is then re-raised unchanged.
    """
    try:
        # Fail fast: reject an unsupported model type before loading data or
        # computing the (potentially expensive) reconstruction AUC.
        if model not in ('logistic', 'mlp'):
            raise ValueError(f'unknown model type {model}')

        # NOTE: the `{}` defaults in the signature are shared between calls;
        # they are only ever read here, never mutated.
        if model_args is None:
            model_args = {}
        merged_train_args = dict(num_epochs=10000, patience=20, lr=0.01, batch_size=100000,
                                 alpha=0, beta=0, weight_decay=0)
        if train_args is not None:
            merged_train_args.update(train_args)
        train_args = merged_train_args

        mmap_mode = 'r' if mmap_features else None
        if isinstance(embedding, (str, Path)):
            coords = np.load(embedding, mmap_mode=mmap_mode)
        else:
            coords = np.asarray(embedding)
        print(f'evaluating with {runs} classification runs.')
        print('graph data loaded')

        # Shallow copy so that attaching features below does not rebind `x`
        # on the caller's `graph.cl_data` object.
        cl_data = copy(graph.cl_data)
        print('classification problem loaded')
        num_labels = cl_data.num_labels
        print(f"{num_labels=}")

        if use_tmp and mmap_features:
            # `tmp_file` must stay referenced for the rest of the function:
            # the temporary file is cleaned up when it is garbage-collected.
            tmp_file = ScopedTemporaryFile(prefix='coords_', suffix='.npy')
            coords_tmp = open_memmap(tmp_file, mode='w+', dtype=coords.dtype, shape=coords.shape)
            coords_tmp[:] = coords[:]
            coords = coords_tmp
            print('features moved to tmp storage')

        print("adding embedding")
        cl_data.x = torch.tensor(coords, dtype=torch.float32)
        print("embedding converted to tensor")
        dim = coords.shape[1]

        print("computing auc")
        auc = reconstruction_auc(coords, graph, dist=dist)

        # Pick the factory that builds a fresh, untrained classifier per run.
        if model == 'logistic':
            def construct_model():
                return Logistic(dim, num_labels, **model_args)
        else:  # model == 'mlp' (validated at the top)
            if 'hidden_dim' in model_args:
                def construct_model():
                    return MLP(dim, output_dim=num_labels, **model_args)
            else:
                def construct_model():
                    return MLP(dim, dim, num_labels, **model_args)

        acc = []
        for _ in range(runs):
            if random_split:
                print("computing new train/test split")
                cl_data.resplit()
            print("constructing model")
            classifier = construct_model()
            print("model constructed")
            classifier = train(cl_data, classifier, device=device, **train_args)
            print("model trained")
            acc.append(accuracy(cl_data, classifier))
            if torch.cuda.is_available():
                print(f'Model accuracy: {acc[-1]}, max memory: {torch.cuda.max_memory_allocated()}, total available memory: {torch.cuda.get_device_properties(torch.cuda.current_device()).total_memory}')

        with ResultsDict(results_file, replace=False, lock=True) as results:
            results.update_dim(dim, auc=auc, acc=acc, model=model, train_args=train_args,
                               model_args=model_args)
    except Exception:
        # Print the traceback before propagating the error to the caller.
        print_exc()
        raise
if __name__ == '__main__':
    # Script entry point: build a parser from the signature of `evaluate`,
    # parse the command line and forward the keyword arguments. The first
    # element returned by `parse()` is not used.
    _unused_args, cli_kwargs = ScriptParser(evaluate).parse()
    evaluate(**cli_kwargs)