Source code for local2global_embedding.run.scripts.mlp_grid_search_eval

#  Copyright (c) 2022. Lucas G. S. Jeub
#
#  Permission is hereby granted, free of charge, to any person obtaining a copy
#  of this software and associated documentation files (the "Software"), to deal
#  in the Software without restriction, including without limitation the rights
#  to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
#  copies of the Software, and to permit persons to whom the Software is
#  furnished to do so, subject to the following conditions:
#
#  The above copyright notice and this permission notice shall be included in all
#  copies or substantial portions of the Software.
#
#  THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
#  IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
#  FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
#  AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
#  LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
#  OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
#  SOFTWARE.
import torch
from pathlib import Path
from itertools import product
from collections.abc import Iterable

from local2global_embedding.embedding.eval import reconstruction_auc
from local2global_embedding.classfication import MLP, train, accuracy
from local2global_embedding.run.utils import get_or_init_client, ScriptParser, ResultsDict, load_data
import numpy as np
from dask import delayed, compute

from local2global_embedding.run.once_per_worker import once_per_worker
from local2global_embedding.run.scripts.utils import load_cl_data


@delayed
def train_task(data, model_args, results_file, batch_size=100, device=None, **train_args):
    results_file = Path(results_file)
    print(f'training MLP({model_args}) with parameters {train_args}')
    model = MLP(input_dim=data.num_features, output_dim=data.num_labels, **model_args)
    model = train(data, model, batch_size=batch_size, device=device, **train_args)
    val_acc = accuracy(data, model, batch_size=batch_size, mode='val')
    train_args['batch_size'] = batch_size
    with ResultsDict(results_file, lock=True) as results:
        if val_acc > results.max('val_acc', dim=data.num_features):
            torch.save(model, results_file.with_name(results_file.stem + f'_d{data.num_features}_bestclassifier.pt'))
        results.update_dim(data.num_features, val_acc=val_acc, model_args=model_args, train_args=train_args)
    print(f'MLP({model_args}) with parameters {train_args} achieved {val_acc=}')
    return val_acc
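
# Note (added for readability, not in the original source): train_task is a dask.delayed
# task, so calling it only builds a task graph; it executes when dask's compute() is
# invoked on the returned objects (see mlp_grid_search_eval below). Whenever a run beats
# the best validation accuracy recorded for this embedding dimension, the trained model
# is saved as '<results stem>_d<num_features>_bestclassifier.pt' next to the results file.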


@delayed
def compute_auc(name, data_root, restrict_lcc, mmap_edges, coords, dist):
    graph = load_data(name, data_root, restrict_lcc=restrict_lcc, mmap_edges=mmap_edges, load_features=False)
    return reconstruction_auc(coords, graph, dist)


@delayed
def compute_test_acc(prob, model_file):
    return accuracy(prob, torch.load(model_file), mode='test')


def _clean_grid_args(args):
    grid = {}
    for key, val in args.items():
        if isinstance(val, Iterable):
            grid[key] = val
        else:
            grid[key] = (val,)
    return grid

def _make_grid(model_args, train_args):
    model_args = _clean_grid_args(model_args)
    train_args = _clean_grid_args(train_args)
    return [(dict(zip(model_args.keys(), vals)), dict(zip(train_args.keys(), tvals)))
            for tvals in product(*train_args.values())
            for vals in product(*model_args.values())]
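
# Illustrative example (not part of the original module): _make_grid expands the
# Cartesian product of the two argument grids into (model_args, train_args) pairs,
# with scalar values wrapped into 1-tuples by _clean_grid_args. For instance,
#   _make_grid({'hidden_dim': (128, 256), 'dropout': 0.5}, {'lr': (0.01, 0.001)})
# yields four combinations, e.g.
#   ({'hidden_dim': 128, 'dropout': 0.5}, {'lr': 0.01}),
#   ({'hidden_dim': 256, 'dropout': 0.5}, {'lr': 0.01}), ...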


def mlp_grid_search_eval(name, data_root, embedding_file, results_file, dist=False, model_args=None, train_args=None,
                         mmap_edges=None, mmap_features=None, random_split=False, use_tmp=False, model='mlp',
                         restrict_lcc=False, device=None, runs=None):
    """
    Run grid search over MLP parameters

    Args:
        name: Name of data set
        data_root: Root folder for downloaded data
        embedding_file: File containing embedding coordinates (npy)
        results_file: File to store search results (json)
        train_args: grid of training arguments (default: {'batch_size': (100000,), 'num_epochs': (1000,),
                    'patience': (20,), 'lr': (0.01, 0.001, 0.0001)})
        mmap_features: if True use mmap to load features
        use_tmp: if True and using mmap, copy features to temporary storage
        model_args: grid of model parameters (default: {'hidden_dim': (128, 256, 512, 1024), 'n_layers': (2, 3, 4),
                    'dropout': (0, 0.25, 0.5), 'batch_norm': (True,)})

    Returns:
        dictionary of best model parameters
    """
    # TODO implement multiple runs with random split
    results_file = Path(results_file)
    final_results_file = results_file.with_name(results_file.stem + '_best.json')
    dim = np.load(embedding_file, mmap_mode='r').shape[1]
    if model != 'mlp':
        raise NotImplementedError('grid search only implemented for MLP')

    with ResultsDict(final_results_file, lock=True) as best_results:
        if best_results.contains_dim(dim):
            return

    client = get_or_init_client()  # launch distributed scheduler if run standalone

    model_grid = {'hidden_dim': (128, 256, 512, 1024), 'n_layers': (2, 3, 4), 'dropout': (0, 0.25, 0.5),
                  'batch_norm': (True,)}
    if model_args is not None:
        model_grid.update(model_args)
    train_grid = {'batch_size': (100000,), 'num_epochs': (1000,), 'patience': (20,), 'lr': (0.01, 0.001, 0.0001)}
    if train_args is not None:
        train_grid.update(train_args)
    arg_grid = _make_grid(model_grid, train_grid)

    if results_file.is_file():
        with ResultsDict(results_file, lock=False) as results:
            runs = results.runs(dim)
        if runs < len(arg_grid):
            raise RuntimeError('Partial results already exist, resume not implemented, stopping.')
        else:
            arg_grid = []

    prob = once_per_worker(lambda: load_cl_data(name, data_root, embedding_file, mmap_features, use_tmp,
                                                restrict_lcc=restrict_lcc))
    task_list = []
    for margs, targs in arg_grid:
        task_list.append(train_task(prob, margs, results_file, **targs))
    auc = compute_auc(name, data_root, restrict_lcc, mmap_edges, prob.x, dist)
    task_list, auc = compute(task_list, auc)
    test_acc = compute_test_acc(prob, results_file.with_name(results_file.stem + f'_d{dim}_bestclassifier.pt')).compute()

    with ResultsDict(results_file) as results:
        with ResultsDict(final_results_file, lock=True, replace=True) as best_results:
            val_list = results.get('val_acc', dim=dim)
            i = np.argmax(val_list)
            best_model_args = results.get('model_args', dim=dim)[i]
            best_train_args = results.get('train_args', dim=dim)[i]
            best_results.update_dim(dim, auc=auc, test_acc=test_acc, best_model_args=best_model_args,
                                    best_train_args=best_train_args)
    print(f'best model is MLP({best_model_args}) trained with {best_train_args}, {test_acc=}')
if __name__ == '__main__':
    ScriptParser(mlp_grid_search_eval).run()
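
# Hypothetical command-line invocation (a sketch only; the exact flag syntax is
# determined by ScriptParser, and the dataset name and paths below are made up):
#   python -m local2global_embedding.run.scripts.mlp_grid_search_eval \
#       --name Cora --data_root /tmp/data \
#       --embedding_file coords.npy --results_file results.json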