Source code for local2global_embedding.run.scripts.evaluate

#  Copyright (c) 2021. Lucas G. S. Jeub
#
#  Permission is hereby granted, free of charge, to any person obtaining a copy
#  of this software and associated documentation files (the "Software"), to deal
#  in the Software without restriction, including without limitation the rights
#  to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
#  copies of the Software, and to permit persons to whom the Software is
#  furnished to do so, subject to the following conditions:
#
#  The above copyright notice and this permission notice shall be included in all
#  copies or substantial portions of the Software.
#
#  THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
#  IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
#  FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
#  AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
#  LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
#  OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
#  SOFTWARE.
from pathlib import Path
from statistics import mean, stdev

import torch
from copy import copy
import numpy as np
from numpy.lib.format import open_memmap
from typing import Optional

from local2global_embedding.embedding.eval import reconstruction_auc
from local2global_embedding.classfication import Logistic, train, accuracy, MLP
from local2global_embedding.run.utils import ResultsDict, ScriptParser, load_classification_problem, load_data
from .utils import ScopedTemporaryFile
from traceback import print_exc


[docs] def evaluate(graph, embedding, results_file: str, dist=False, device: Optional[str]=None, runs=50, train_args={}, mmap_features=False, random_split=False, model='logistic', model_args={}, use_tmp=False): try: train_args_default = dict(num_epochs=10000, patience=20, lr=0.01, batch_size=100000, alpha=0, beta=0, weight_decay=0) train_args_default.update(train_args) train_args = train_args_default mmap_mode = 'r' if mmap_features else None if isinstance(embedding, str) or isinstance(embedding, Path): coords = np.load(embedding, mmap_mode=mmap_mode) else: coords = np.asarray(embedding) print(f'evaluating with {runs} classification runs.') print('graph data loaded') cl_data = copy(graph.cl_data) print('classification problem loaded') num_labels = cl_data.num_labels print(f"{num_labels=}") if use_tmp and mmap_features: tmp_file = ScopedTemporaryFile(prefix='coords_', suffix='.npy') # path of temporary file that is automatically cleaned up when garbage-collected coords_tmp = open_memmap(tmp_file, mode='w+', dtype=coords.dtype, shape=coords.shape) coords_tmp[:] = coords[:] coords = coords_tmp print('features moved to tmp storage') print("adding embedding") cl_data.x = torch.tensor(coords, dtype=torch.float32) print("embedding converted to tensor") dim = coords.shape[1] print("computing auc") auc = reconstruction_auc(coords, graph, dist=dist) acc = [] model_str = model if model == 'logistic': def construct_model(): return Logistic(dim, num_labels, **model_args) elif model == 'mlp': if 'hidden_dim' in model_args: def construct_model(): return MLP(dim, output_dim=num_labels, **model_args) else: def construct_model(): return MLP(dim, dim, num_labels, **model_args) else: raise ValueError(f'unknown model type {model}') for _ in range(runs): if random_split: print("computing new train/test split") cl_data.resplit() print("constructing model") model = construct_model() print("model constructed") model = train(cl_data, model, device=device, **train_args) print("model trained") acc.append(accuracy(cl_data, model)) if torch.cuda.is_available(): print(f'Model accuracy: {acc[-1]}, max memory: {torch.cuda.max_memory_allocated()}, total available memory: {torch.cuda.get_device_properties(torch.cuda.current_device()).total_memory}') with ResultsDict(results_file, replace=False, lock=True) as results: results.update_dim(dim, auc=auc, acc=acc, model=model_str, train_args=train_args, model_args=model_args) except Exception: print_exc() raise
if __name__ == '__main__': parser = ScriptParser(evaluate) args, kwargs = parser.parse() evaluate(**kwargs)