Source code for local2global_embedding.utils

#  Copyright (c) 2021. Lucas G. S. Jeub
#
#  Permission is hereby granted, free of charge, to any person obtaining a copy
#  of this software and associated documentation files (the "Software"), to deal
#  in the Software without restriction, including without limitation the rights
#  to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
#  copies of the Software, and to permit persons to whom the Software is
#  furnished to do so, subject to the following conditions:
#
#  The above copyright notice and this permission notice shall be included in all
#  copies or substantial portions of the Software.
#
#  THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
#  IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
#  FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
#  AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
#  LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
#  OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
#  SOFTWARE.
import torch
import torch.nn
from tempfile import TemporaryFile
from time import perf_counter


def speye(n, dtype=torch.float):
    """identity matrix of dimension n as sparse_coo_tensor."""
    return torch.sparse_coo_tensor(torch.tile(torch.arange(n, dtype=torch.long), (2, 1)),
                                   torch.ones(n, dtype=dtype),
                                   (n, n))

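# Usage sketch (not part of the original module): ``speye`` places ones on the
# diagonal of an otherwise empty sparse COO matrix, e.g.
#
#   >>> speye(3).to_dense()
#   tensor([[1., 0., 0.],
#           [0., 1., 0.],
#           [0., 0., 1.]])
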
def get_device(model: torch.nn.Module):
    return next(model.parameters()).device

def set_device(device):
    if device is None:
        if torch.cuda.is_available():
            device = torch.device('cuda')
        else:
            device = torch.device('cpu')
    else:
        device = torch.device(device)
    return device

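# Usage sketch (illustrative; ``model`` is an assumed torch.nn.Module, not
# defined in this module):
#
#   device = set_device(None)    # torch.device('cuda') if available, else 'cpu'
#   model = model.to(device)
#   get_device(model)            # device of the model's first parameter
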
class EarlyStopping:
    """
    Context manager for early stopping
    """

    def __init__(self, patience, delta=0):
        """
        Initialise early stopping context manager

        Args:
            patience: wait ``patience`` number of epochs without loss improvement before stopping
            delta: minimum improvement to consider significant (default: 0)
        """
        self.patience = patience
        self.delta = delta
        self.best_loss = float('inf')
        self.count = 0
        self._file = TemporaryFile()

    def __enter__(self):
        self.best_loss = float('inf')
        self.count = 0
        if self._file.closed:
            self._file = TemporaryFile()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self._file.close()

    def _save_model(self, model):
        self._file.seek(0)
        torch.save(model.state_dict(), self._file)

    def _load_model(self, model: torch.nn.Module):
        self._file.seek(0)
        model.load_state_dict(torch.load(self._file))

    def __call__(self, loss, model):
        """
        check stopping criterion and save or restore model state as appropriate

        Args:
            loss: loss value used for the stopping criterion
            model: model whose state is checkpointed on improvement and restored on stopping

        Returns:
            ``True`` if training should be stopped, ``False`` otherwise
        """
        loss = float(loss)  # make sure no tensors used here to avoid propagating gradients
        if loss >= self.best_loss - self.delta:
            self.count += 1
        else:
            self.count = 0
        if loss < self.best_loss:
            self.best_loss = loss
            self._save_model(model)
        if self.count > self.patience:
            self._load_model(model)
            return True
        else:
            return False

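# Usage sketch (illustrative; ``model``, ``optimizer``, ``max_epochs`` and
# ``train_one_epoch`` are assumed placeholders, not part of this module).
# Model state is checkpointed to a temporary file whenever the loss improves
# and restored once ``patience`` epochs pass without improvement:
#
#   with EarlyStopping(patience=10) as stop:
#       for epoch in range(max_epochs):
#           loss = train_one_epoch(model, optimizer)
#           if stop(loss, model):
#               break  # weights are now restored to the best checkpoint
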
class Timer:
    """
    Context manager for accumulating execution time

    Adds the time taken within the block to a running total.
    """

    def __init__(self):
        self.total = 0.0

    def __enter__(self):
        self.tic = perf_counter()
        return self

    def __exit__(self, exc_type, exc_value, exc_tb):
        self.total += perf_counter() - self.tic

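# Usage sketch (illustrative; ``batches`` and ``process`` are assumed
# placeholders): the same Timer can wrap a block repeatedly and accumulates
# the total time across all entries:
#
#   timer = Timer()
#   for batch in batches:
#       with timer:
#           process(batch)
#   print(f'total processing time: {timer.total:.3f}s')
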
def flatten(l, ltypes=(list, tuple)):
    """flatten arbitrarily nested lists/tuples into a flat sequence of the outermost type."""
    if isinstance(l, ltypes):
        ltype = type(l)
        l = list(l)
        i = 0
        while i < len(l):
            while isinstance(l[i], ltypes):
                if not l[i]:
                    l.pop(i)
                    i -= 1
                    break
                else:
                    l[i:i + 1] = l[i]
            i += 1
        return ltype(l)
    else:
        return l
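
# Usage sketch (not part of the original module): ``flatten`` iteratively
# expands nested lists/tuples in place and returns a flat sequence of the
# outermost type:
#
#   >>> flatten([1, [2, 3], (4, [5])])
#   [1, 2, 3, 4, 5]
#   >>> flatten((1, (2,), 3))
#   (1, 2, 3)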