Source code for local2global_embedding.run.scripts.train

#  Copyright (c) 2021. Lucas G. S. Jeub
#
#  Permission is hereby granted, free of charge, to any person obtaining a copy
#  of this software and associated documentation files (the "Software"), to deal
#  in the Software without restriction, including without limitation the rights
#  to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
#  copies of the Software, and to permit persons to whom the Software is
#  furnished to do so, subject to the following conditions:
#
#  The above copyright notice and this permission notice shall be included in all
#  copies or substantial portions of the Software.
#
#  THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
#  IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
#  FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
#  AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
#  LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
#  OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
#  SOFTWARE.
import json
from pathlib import Path
from typing import Optional
import os
from collections.abc import Iterable
from time import perf_counter

import numpy as np
import torch
from atomicwrites import atomic_write

import local2global_embedding.embedding.gae as gae
import local2global_embedding.embedding.dgi as dgi
import local2global_embedding.embedding.train as training
from local2global_embedding.embedding.eval import reconstruction_auc
from local2global_embedding.network import TGraph
from local2global_embedding.utils import speye, set_device
from local2global_embedding.run.utils import ResultsDict, ScriptParser
from local2global import Patch
from local2global.patch import FilePatch


def select_loss(model):
    if isinstance(model, gae.VGAE):
        return gae.VGAE_loss
    elif isinstance(model, gae.GAE):
        return gae.GAE_loss
    elif isinstance(model, dgi.DGI):
        return dgi.DGILoss()


def create_model(model, dim, hidden_dim, num_features, dist):
    if model == 'VGAE':
        return gae.VGAE(dim, hidden_dim, num_features, dist)
    elif model == 'GAE':
        return gae.GAE(dim, hidden_dim, num_features, dist)
    elif model == 'DGI':
        return dgi.DGI(num_features, dim)

class Count:
    def __init__(self):
        self.count = 0

    def __call__(self, *args, **kwargs):
        self.count += 1

def train(data, model, lr, num_epochs: int, patience: int, verbose: bool, results_file: str, dim: int,
          hidden_multiplier: Optional[int] = None, dist=False, save_coords=False):
    """
    Train a model on data and return the resulting embedding as a patch.

    Args:
        data: training data (provides ``nodes``, ``num_features`` and ``device``)
        model: name of the model to initialise ('VGAE', 'GAE' or 'DGI')
        lr: learning rate
        num_epochs: maximum number of training epochs
        patience: early stopping patience
        verbose: if True, print loss during training
        results_file: json file of existing results; the model and coordinates file names
            are derived from it by replacing its ``_info`` suffix
        dim: embedding dimension
        hidden_multiplier: hidden dimension is ``dim * hidden_multiplier``
        dist: use distance decoder for reconstruction
        save_coords: if True, save coordinates to a ``.npy`` file and return a FilePatch,
            otherwise return an in-memory Patch

    The device used for training is taken from ``data.device``.
    """
    device = data.device
    nodes = data.nodes
    if isinstance(nodes, torch.Tensor):
        nodes = nodes.cpu().numpy()
    print(f'Launched training for model {model}_d{dim} with cuda devices '
          f'{os.environ.get("CUDA_VISIBLE_DEVICES", "unavailable")} and device={device}')
    results_file = Path(results_file)
    model_file = results_file.with_name(results_file.stem.replace("_info", "_model") + ".pt")
    coords_file = results_file.with_name(results_file.stem.replace("_info", "_coords") + ".npy")
    model = create_model(model, dim, dim * hidden_multiplier, data.num_features, dist).to(device)
    loss_fun = select_loss(model)
    if results_file.exists():
        # results already exist: reuse saved coordinates, or re-embed with the saved model
        with open(results_file) as f:
            res = json.load(f)
        if coords_file.exists():
            if save_coords:
                return FilePatch(nodes, str(coords_file))
            else:
                return Patch(nodes, np.load(coords_file))
        else:
            model.load_state_dict(torch.load(model_file))
            model.eval()
            coords = model.embed(data)
            if save_coords:
                np.save(coords_file, coords.cpu().numpy())
                return FilePatch(nodes, str(coords_file))
            else:
                return Patch(nodes, coords.cpu().numpy())
    else:
        # no existing results: train from scratch and record timing and evaluation metrics
        tic = perf_counter()
        model.reset_parameters()
        ep_count = Count()
        model = training.train(data, model, loss_fun, num_epochs, patience, lr, verbose=verbose, logger=ep_count)
        model.eval()
        coords = model.embed(data)
        toc = perf_counter()
        auc = reconstruction_auc(coords, data, dist=dist)
        loss = float(loss_fun(model, data))
        torch.save(model.state_dict(), model_file)
        if save_coords:
            np.save(coords_file, coords.cpu().numpy())
        with atomic_write(results_file, overwrite=True) as f:  # atomic write avoids any chance of losing existing data
            json.dump({"dim": dim,
                       "loss": loss,
                       "auc": auc,
                       "train_time": toc - tic,
                       "train_epochs": ep_count.count,
                       "args": {"lr": lr, "num_epochs": num_epochs, "patience": patience, "dist": dist}
                       }, f)
        if save_coords:
            return FilePatch(nodes, str(coords_file))
        else:
            return Patch(nodes, coords.cpu().numpy())

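# Illustrative usage (a sketch, not part of the original module): calling ``train``
# directly from Python. ``patch_data`` is assumed to be an already-loaded graph object
# on the desired device, and the numeric values are placeholders; the results file must
# follow the ``*_info.json`` naming convention so that the ``_model``/``_coords`` file
# names can be derived from it.
#
#     patch = train(patch_data, 'GAE', lr=0.01, num_epochs=1000, patience=20,
#                   verbose=True, results_file='patch0_GAE_d2_info.json',
#                   dim=2, hidden_multiplier=4, dist=False, save_coords=True)
#     # returns a FilePatch whose coordinates are stored in 'patch0_GAE_d2_coords.npy'
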
if __name__ == '__main__':
    ScriptParser(train).run()
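# Hypothetical command-line invocation (an assumption, not confirmed by this file):
# ScriptParser exposes ``train`` as the script entry point; assuming it maps the
# function's keyword arguments to ``--name value`` options, a run might look like
#
#     python -m local2global_embedding.run.scripts.train \
#         --data <patch data> --model GAE --lr 0.01 --num_epochs 1000 --patience 20 \
#         --verbose True --results_file patch0_GAE_d2_info.json --dim 2 \
#         --hidden_multiplier 4 --save_coords True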