# Copyright (c) 2021. Lucas G. S. Jeub
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import json
from pathlib import Path
from typing import Optional
import os
from collections.abc import Iterable
from time import perf_counter
import numpy as np
import torch
from atomicwrites import atomic_write
import local2global_embedding.embedding.gae as gae
import local2global_embedding.embedding.dgi as dgi
import local2global_embedding.embedding.train as training
from local2global_embedding.embedding.eval import reconstruction_auc
from local2global_embedding.network import TGraph
from local2global_embedding.utils import speye, set_device
from local2global_embedding.run.utils import ResultsDict, ScriptParser
from local2global import Patch
from local2global.patch import FilePatch
[docs]
def select_loss(model):
if isinstance(model, gae.VGAE):
return gae.VGAE_loss
elif isinstance(model, gae.GAE):
return gae.GAE_loss
elif isinstance(model, dgi.DGI):
return dgi.DGILoss()
[docs]
def create_model(model, dim, hidden_dim, num_features, dist):
if model == 'VGAE':
return gae.VGAE(dim, hidden_dim, num_features, dist)
elif model == 'GAE':
return gae.GAE(dim, hidden_dim, num_features, dist)
elif model == 'DGI':
return dgi.DGI(num_features, dim)
[docs]
class Count:
[docs]
def __init__(self):
self.count = 0
def __call__(self, *args, **kwargs):
self.count += 1
[docs]
def train(data, model, lr, num_epochs: int, patience: int, verbose: bool, results_file: str,
dim: int, hidden_multiplier: Optional[int] = None, dist=False, save_coords=False):
"""
train model on data
Args:
data: path to training data
model: str that will be evaluated to initialise the model
lr: learning rate
num_epochs: maximum number of training epochs
patience: early stopping patience
verbose: if True, print loss during training
results_file: json file of existing results
dist: use distance decoder for reconstruction
device: device to use for training (e.g., 'cuda', 'cpu')
"""
device = data.device
nodes = data.nodes
if isinstance(nodes, torch.Tensor):
nodes = nodes.cpu().numpy()
print(f'Launched training for model {model}_d{dim} with cuda devices {os.environ.get("CUDA_VISIBLE_DEVICES", "unavailable")} and device={device}')
results_file = Path(results_file)
model_file = results_file.with_name(results_file.stem.replace("_info", "_model") + ".pt")
coords_file = results_file.with_name(results_file.stem.replace("_info", "_coords") + ".npy")
model = create_model(model, dim, dim * hidden_multiplier, data.num_features, dist).to(device)
loss_fun = select_loss(model)
if results_file.exists():
with open(results_file) as f:
res = json.load(f)
if coords_file.exists():
if save_coords:
return FilePatch(nodes, str(coords_file))
else:
return Patch(nodes, np.load(coords_file))
else:
model.load_state_dict(torch.load(model_file))
model.eval()
coords = model.embed(data)
if save_coords:
np.save(coords_file, coords.cpu().numpy())
return FilePatch(nodes, str(coords_file))
else:
return Patch(nodes, coords.cpu().numpy())
else:
tic = perf_counter()
model.reset_parameters()
ep_count = Count()
model = training.train(data, model, loss_fun, num_epochs, patience, lr, verbose=verbose, logger=ep_count)
model.eval()
coords = model.embed(data)
toc = perf_counter()
auc = reconstruction_auc(coords, data, dist=dist)
loss = float(loss_fun(model, data))
torch.save(model.state_dict(), model_file)
if save_coords:
np.save(coords_file, coords.cpu().numpy())
with atomic_write(results_file, overwrite=True) as f: # this should avoid any chance of loosing existing data
json.dump({"dim": dim,
"loss": loss,
"auc": auc,
"train_time": toc-tic,
"tain_epochs": ep_count.count,
"args": {"lr": lr,
"num_epochs": num_epochs,
"patience": patience,
"dist": dist}
}, f)
if save_coords:
return FilePatch(nodes, str(coords_file))
else:
return Patch(nodes, coords.cpu().numpy())
if __name__ == '__main__':
ScriptParser(train).run()