Source code for local2global_embedding.classification

import torch
import torch.utils.data
import torch.utils
from hyperopt import fmin, hp, tpe, Trials, space_eval, rand
from hyperopt.pyll import scope
import pandas as pd
import numpy as np
from math import log, log2, ceil
from copy import deepcopy
from itertools import count, chain, product, groupby
from tqdm.auto import tqdm
from collections.abc import Sequence


from local2global_embedding.utils import EarlyStopping, get_device


class Logistic(torch.nn.Module):
    def __init__(self, input_dim, output_dim, bias=True):
        super().__init__()
        self.linear = torch.nn.Linear(input_dim, output_dim, bias)
        self.softmax = torch.nn.LogSoftmax(dim=-1)
        self.reset_parameters()

    def reset_parameters(self):
        torch.nn.init.xavier_uniform_(self.linear.weight.data)
        if self.linear.bias is not None:
            self.linear.bias.data.fill_(0.0)

    def forward(self, x):
        return self.softmax(self.linear(x))

def _mlp_hidden_layer(in_dim, out_dim, batch_norm=True, dropout=0, relu_last=False):
    lin = torch.nn.Linear(in_dim, out_dim, bias=True)
    nl = torch.nn.ReLU()
    if batch_norm:
        bn = torch.nn.BatchNorm1d(out_dim)
        if relu_last:
            layer_list = (lin, bn, nl)
        else:
            layer_list = (lin, nl, bn)
    else:
        layer_list = (lin, nl)
    if dropout > 0:
        return *layer_list, torch.nn.Dropout(dropout)
    else:
        return layer_list

class MLP(torch.nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, n_layers=2, batch_norm=False, dropout=0, relu_last=False):
        super().__init__()
        self.network = torch.nn.Sequential(
            *_mlp_hidden_layer(input_dim, hidden_dim, batch_norm, dropout, relu_last),
            *chain.from_iterable(
                _mlp_hidden_layer(hidden_dim, hidden_dim, batch_norm, dropout, relu_last)
                for _ in range(n_layers - 2)
            ),
            torch.nn.Linear(hidden_dim, output_dim, bias=True),
            torch.nn.LogSoftmax(dim=-1),
        )
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        self.n_layers = n_layers
        self.reset_parameters()

    def forward(self, x):
        return self.network(x)

    def reset_parameters(self):
        for layer in self.network:
            if isinstance(layer, torch.nn.Linear):
                torch.nn.init.kaiming_normal_(layer.weight, nonlinearity='relu')

    def __repr__(self):
        return f"{self.__class__.__name__}({self.hidden_dim}, {self.n_layers})"

class SNN(torch.nn.Module):
    def __init__(self, input_dim, hidden_dim, output_dim, n_layers=2):
        super().__init__()
        self.network = torch.nn.Sequential(
            torch.nn.Linear(input_dim, hidden_dim, bias=False),
            torch.nn.SELU(),
            *chain.from_iterable(
                (torch.nn.Linear(hidden_dim, hidden_dim, bias=False), torch.nn.SELU())
                for _ in range(n_layers - 2)
            ),
            torch.nn.Linear(hidden_dim, output_dim, bias=False),
        )
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        self.n_layers = n_layers
        self.reset_parameters()

    def forward(self, x):
        return self.network(x)

    def reset_parameters(self):
        for layer in self.network:
            if isinstance(layer, torch.nn.Linear):
                torch.nn.init.kaiming_normal_(layer.weight, nonlinearity='linear')

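# Illustrative usage sketch (not part of the module itself): constructing an MLP
# classifier and applying it to a batch of embeddings. The dimensions below are
# hypothetical.
#
#     model = MLP(input_dim=128, hidden_dim=256, output_dim=7, n_layers=2,
#                 batch_norm=True, dropout=0.5)
#     log_probs = model(torch.randn(32, 128))  # (32, 7) log-probabilities from LogSoftmax
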
def random_split(y, num_train_per_class=20, num_val=500):
    split = {}
    num_classes = int(y.max().item()) + 1
    train_mask = torch.zeros(y.size(), dtype=torch.bool)
    val_mask = torch.zeros(y.size(), dtype=torch.bool)
    test_mask = torch.zeros(y.size(), dtype=torch.bool)
    for c in range(num_classes):
        idx = (y == c).nonzero(as_tuple=False).view(-1)
        idx = idx[torch.randperm(idx.size(0))]
        idx = idx[:num_train_per_class]
        train_mask[idx] = True
    split['train'] = train_mask.nonzero().view(-1)
    remaining = (~train_mask).nonzero().view(-1)
    remaining = remaining[y[remaining] >= 0]  # only consider labelled data
    remaining = remaining[torch.randperm(remaining.size(0))]
    split['val'] = remaining[:num_val]
    split['test'] = remaining[num_val:]
    return split

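# Illustrative sketch (not part of the module itself): random_split returns a dict of
# disjoint index tensors keyed by 'train', 'val' and 'test'; a label of -1 marks an
# unlabelled node. The label vector below is hypothetical.
#
#     y = torch.tensor([0, 1, 0, 1, 0, 1, -1, 0, 1])
#     split = random_split(y, num_train_per_class=1, num_val=3)
#     # split['train'] holds one index per class; the remaining labelled indices are
#     # shuffled into split['val'] (first 3) and split['test'] (the rest)
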
class ClassificationProblem:
    def __init__(self, y, x=None, split=None):
        self.y = y
        self.x = x
        self.num_labels = int(y.max()) + 1
        if split is None:
            self.resplit()
        else:
            self.split = split
        self._val_data = None
        self._test_data = None
        self._train_data = None

    def resplit(self, num_train_per_class=20, num_val=500):
        self.split = random_split(self.y, num_train_per_class=num_train_per_class, num_val=num_val)

    @property
    def split(self):
        return {'train': self.train_index, 'val': self.val_index, 'test': self.test_index}

    @split.setter
    def split(self, split):
        self.train_index = split['train']
        self.val_index = split['val']
        self.test_index = split['test']

    @property
    def num_features(self):
        if self.x is None:
            return None
        else:
            return self.x.shape[1]

    def training_data(self, include_unlabeled=False):
        if self.x is None:
            raise RuntimeError('Need to set embedding first')
        if include_unlabeled:
            y = torch.tensor(self.y)
            y[self.val_index] = -1
            y[self.test_index] = -1
            if isinstance(self.x, np.memmap):
                return MMapData(self.x, y)
            else:
                x = torch.as_tensor(self.x)
                return torch.utils.data.TensorDataset(x, y)
        else:
            if self._train_data is None:
                self._train_data = torch.utils.data.TensorDataset(
                    torch.as_tensor(self.x[self.train_index, :]),
                    torch.as_tensor(self.y[self.train_index]),
                )
            return self._train_data

    def validation_data(self):
        if self._val_data is None:
            self._val_data = torch.utils.data.TensorDataset(
                torch.as_tensor(self.x[self.val_index, :]),
                torch.as_tensor(self.y[self.val_index]),
            )
        return self._val_data

    def test_data(self):
        if self._test_data is None:
            self._test_data = torch.utils.data.TensorDataset(
                torch.as_tensor(self.x[self.test_index, :]),
                torch.as_tensor(self.y[self.test_index]),
            )
        return self._test_data

    def labeled_data(self):
        return torch.utils.data.TensorDataset(
            torch.as_tensor(self.x[self.y >= 0, :]),
            torch.as_tensor(self.y[self.y >= 0]),
        )

    def all_data(self):
        if isinstance(self.x, np.memmap):
            return MMapData(self.x, self.y)
        else:
            return torch.utils.data.TensorDataset(torch.as_tensor(self.x), torch.as_tensor(self.y))

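# Illustrative sketch (not part of the module itself): wrapping a node embedding and
# label vector in a ClassificationProblem. The arrays below are hypothetical.
#
#     x = np.random.randn(1000, 64).astype(np.float32)   # node embedding
#     y = torch.randint(0, 5, (1000,))                    # node labels
#     problem = ClassificationProblem(y, x=x)             # draws a random train/val/test split
#     train_set = problem.training_data()                 # TensorDataset of the training split
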
class MMapData(torch.utils.data.Dataset):
    def __init__(self, x, y):
        self.x = x
        self.y = y

    def __len__(self):
        return len(self.y)

    def __getitem__(self, item):
        return torch.as_tensor(self.x[item]), torch.as_tensor(self.y[item])

class logger:
    def __init__(self, data, model):
        self.loss = []
        self.val_loss = []
        self.test_loss = []
        self.data = data
        self.model = model

    def __call__(self, l):
        self.loss.append(l)
        self.val_loss.append(accuracy(self.data, self.model, mode='val'))
        self.test_loss.append(accuracy(self.data, self.model, mode='test'))

class VATloss(torch.nn.Module):
    def __init__(self, epsilon, xi=1e-6, it=1):
        super().__init__()
        self.epsilon = epsilon
        self.xi = xi
        self.it = it
        self.divergence = torch.nn.KLDivLoss(reduction='batchmean', log_target=True)

    def forward(self, model: torch.nn.Module, x, p=None):
        if p is None:
            with torch.no_grad():
                p = model(x)
        with torch.no_grad():
            r = torch.randn_like(x)
            r /= torch.norm(r, p=2, dim=-1, keepdim=True)
            p = p.detach()
        for _ in range(self.it):
            with torch.no_grad():
                r = (self.xi * r).clone().detach()
            r.requires_grad = True
            model.zero_grad()
            d = self.divergence(model(x + r), p)
            d.backward()
            with torch.no_grad():
                r.grad += 1e-16
                r = r.grad / torch.norm(r.grad, p=2, dim=-1, keepdim=True)
        with torch.no_grad():
            r = (self.epsilon * r).detach()
        model.zero_grad()
        div = self.divergence(model(x + r), p)
        # div = torch.sum(p * self.logsoftmax(model(x + r_adv)))
        return div

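# Illustrative sketch (not part of the module itself): VATloss perturbs the inputs
# along an adversarially estimated direction of norm epsilon and returns the KL
# divergence between the model's log-probabilities on clean and perturbed inputs,
# so it can be used as a regulariser without labels. Names below are hypothetical.
#
#     vat = VATloss(epsilon=1.0, xi=1e-6, it=1)
#     x_batch = torch.randn(32, 128)
#     reg = vat(model, x_batch)    # scalar virtual adversarial training penalty
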
class EntMin(torch.nn.Module):
    def forward(self, logits):
        return torch.mean(torch.distributions.Categorical(logits=logits).entropy(), dim=0)

class BatchedData(torch.utils.data.Dataset):
    def __init__(self, data: torch.utils.data.TensorDataset, batch_size):
        self.data = data
        self.batch_size = batch_size

    def __getitem__(self, item):
        index = item * self.batch_size
        return self.data[index:index + self.batch_size]

    def __len__(self):
        return len(range(0, len(self.data), self.batch_size))

def train(data: ClassificationProblem, model: torch.nn.Module, num_epochs, batch_size, lr=0.01,
          batch_logger=lambda loss: None, epoch_logger=lambda epoch: None, device=None,
          epsilon=1, alpha=0, beta=0, weight_decay=1e-2, decay_lr=False, xi=1e-6, vat_it=1,
          teacher_alpha=0, beta_1=0.9, beta_2=0.999, adam_epsilon=1e-8, patience=None):
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if alpha > 0 or beta > 0:
        train_data = data.training_data(include_unlabeled=True)
    else:
        train_data = data.training_data(include_unlabeled=False)
    model = model.to(device)

    if teacher_alpha:
        teacher = deepcopy(model)
        for param in teacher.parameters():
            param.detach_()
        it_count = count()

        def update_teacher():
            it = next(it_count)
            alpha = min(1 - 1 / (it + 1), teacher_alpha)
            for teacher_param, model_param in zip(teacher.parameters(), model.parameters()):
                teacher_param.data.mul_(alpha).add_(model_param, alpha=1 - alpha)
    else:
        teacher = None

        def update_teacher():
            pass

    # optimizer = torch.optim.Adamax(model.parameters(), lr=lr, weight_decay=weight_decay)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay,
                                 betas=(beta_1, beta_2), eps=adam_epsilon)
    if batch_size < len(train_data):
        train_data = train_data.to(device)
    # optimizer = torch.optim.SGD(model.parameters(), lr=lr, weight_decay=weight_decay)
    data_loader = torch.utils.data.DataLoader(BatchedData(train_data, batch_size=batch_size),
                                              batch_size=1, shuffle=True, collate_fn=lambda b: b[0],
                                              pin_memory=batch_size < len(train_data))
    if decay_lr:
        lr_sched = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, len(data_loader) * num_epochs)

        def step_lr():
            lr_sched.step()
    else:
        def step_lr():
            pass

    if patience is None:
        patience = float('inf')
    criterion = torch.nn.NLLLoss(reduction='mean', ignore_index=-1)
    vat_loss = VATloss(epsilon=epsilon, xi=xi, it=vat_it)
    ent_loss = EntMin()
    if alpha == 0:
        if beta == 0:
            def loss_fun(model, x, y):
                return criterion(model(x), y)
        else:
            def loss_fun(model, x, y):
                p = model(x)
                return criterion(p, y) + beta * ent_loss(p)
    else:
        if beta == 0:
            def loss_fun(model, x, y):
                p = model(x)
                return criterion(p, y) + alpha * vat_loss(model, x, p)
        else:
            def loss_fun(model, x, y):
                p = model(x)
                return criterion(p, y) + alpha * vat_loss(model, x, p) + beta * ent_loss(p)

    x_val, y_val = data.validation_data()[:]
    x_val = x_val.to(device=device, dtype=torch.float32)
    y_val = y_val.to(device=device)
    with EarlyStopping(patience) as stop:
        with tqdm(total=num_epochs, desc='training epoch') as progress:
            for e in range(num_epochs):
                model.train()
                for x, y in data_loader:
                    x = x.to(device=device, dtype=torch.float32, non_blocking=True).view(-1, x.size(-1))
                    y = y.to(device=device, non_blocking=True).view(-1)
                    optimizer.zero_grad()
                    loss = loss_fun(model, x, y)
                    loss.backward()
                    # torch.nn.utils.clip_grad_norm_(model.parameters(), clip_norm)
                    optimizer.step()
                    step_lr()
                    update_teacher()
                    batch_logger(float(loss))
                epoch_logger(e)
                model.eval()
                vl = criterion(model(x_val), y_val)
                progress.update()
                # progress.write(f'validation loss: {vl}')
                if stop(vl, model):
                    print(f'early stopping at epoch {e}')
                    break
    return model

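# Illustrative sketch (not part of the module itself): a typical call combining the
# negative log-likelihood objective with VAT (alpha) and entropy minimisation (beta).
# All values are hypothetical.
#
#     model = MLP(input_dim=problem.num_features, hidden_dim=256,
#                 output_dim=problem.num_labels)
#     model = train(problem, model, num_epochs=1000, batch_size=100000, lr=1e-3,
#                   alpha=1.0, beta=0.1, patience=20)
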
def predict(x, model: torch.nn.Module):
    eval_state = model.training
    model.eval()
    x = x.to(device=get_device(model), dtype=torch.float32)
    with torch.no_grad():
        labels = torch.argmax(model(x), dim=-1)
    model.training = eval_state
    return labels

def accuracy(data: ClassificationProblem, model: torch.nn.Module, mode='test', batch_size=None):
    if mode == 'test':
        data = data.test_data()
    elif mode == 'val':
        data = data.validation_data()
    elif mode == 'all':
        data = data.labeled_data()
    else:
        raise ValueError(f'unknown mode {mode}')
    if batch_size is None:
        batch_size = len(data)
    loader = torch.utils.data.DataLoader(data, batch_size)
    val = 0
    with torch.no_grad():
        for x, y in loader:
            x = x.to(device=get_device(model), dtype=torch.float32)
            y = y.to(get_device(model))
            val += torch.sum(predict(x, model) == y).cpu().item()
    return val / len(data)

def validation_accuracy(data, model: torch.nn.Module, batch_size=None):
    return accuracy(data, model, mode='val', batch_size=batch_size)

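# Illustrative sketch (not part of the module itself): evaluating a trained model on
# the different splits of a ClassificationProblem (names are hypothetical).
#
#     val_acc = validation_accuracy(problem, model)
#     test_acc = accuracy(problem, model, mode='test')
#     labels = predict(torch.as_tensor(problem.x), model)   # predicted labels for all nodes
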
class HyperTuneObjective:
    def __init__(self, data, n_tries=1, **kwargs):
        self.data = data
        self.args = kwargs
        self.min_loss = float('inf')
        self.best_parameters = deepcopy(self.args)
        self.n_tries = n_tries

    def __call__(self, args):
        cum_loss = 0
        for _ in range(self.n_tries):
            model = train(self.data, **args, **self.args)
            loss = 1 - validation_accuracy(self.data, model)
            cum_loss += loss
            if loss < self.min_loss:
                self.min_loss = loss
                self.best_parameters.update(deepcopy(args))
            model.reset_parameters()
        return cum_loss / self.n_tries

@scope.define
def mlp_model(*args, **kwargs):
    return MLP(*args, **kwargs)


@scope.define
def linear_model(in_dim, out_dim):
    return torch.nn.Linear(in_dim, out_dim)


@scope.define
def snn_model(in_dim, hidden_dim, out_dim, n_layers):
    return SNN(in_dim, hidden_dim, out_dim, n_layers)

def _make_space(args):
    log_vars = {'epsilon', 'weight_decay', 'xi', 'alpha', 'beta', 'lr'}
    int_vars = {'n_layers'}
    uniform_vars = {'dropout'}
    choice_vars = {'hidden', 'hidden_dim', 'batch_norm'}  # include 'hidden_dim' so the MLP width range is searched
    space = {}
    fixed = {}
    for key, val in args.items():
        if isinstance(val, Sequence):
            if key in log_vars:
                space[key] = hp.loguniform(key, log(min(val)), log(max(val)))
            if key in int_vars:
                space[key] = hp.uniformint(key, min(val), max(val))
            if key in uniform_vars:
                space[key] = hp.uniform(key, min(val), max(val))
            if key in choice_vars:
                space[key] = hp.choice(key, val)
        else:
            space[key] = val
    return space

def hyper_tune(data: ClassificationProblem, max_evals=100, n_tries=1, random_search=False,
               model_args=None, train_args=None):
    _model_args = {'hidden_dim': (128, 256, 512, 1024), 'n_layers': (2, 4), 'dropout': (0, 1),
                   'batch_norm': (False, True)}
    _train_args = {'batch_size': 100000, 'num_epochs': 1000, 'patience': 20, 'lr': (1e-4, 1e-1)}
    model_space = _make_space(_model_args)
    train_space = _make_space(_train_args)
    objective = HyperTuneObjective(data, n_tries=n_tries)
    trials = Trials()
    in_dim = data.num_features
    out_dim = data.num_labels
    search_space = {
        **train_space,
        'model': scope.mlp_model(input_dim=in_dim, output_dim=out_dim, **model_space),
    }

    def transform(space, value):
        value = space_eval(space, {key: val[0] for key, val in value.items() if val})
        if isinstance(value, torch.nn.Module):
            value = repr(value)
        return value

    if random_search:
        args = fmin(fn=objective, space=search_space, algo=rand.suggest, max_evals=max_evals, trials=trials)
    else:
        args = fmin(fn=objective, space=search_space, algo=tpe.suggest, max_evals=max_evals, trials=trials)
    params = {**train_space, **model_space}
    results = pd.DataFrame.from_records(
        {key: transform(params[key], t['misc']['vals']) for key in t['misc']['vals']}
        for t in trials.trials
    )
    results['loss'] = trials.losses()
    best_args = space_eval(params, args)
    best_model_args = {key: best_args[key] for key in _model_args}
    best_train_args = {key: best_args[key] for key in _train_args}
    return best_model_args, best_train_args, results

def plot_hyper_results(results, plot_kws=None, diag_kws=None, **kwargs):
    import seaborn as sns
    import matplotlib.pyplot as plt
    plot_kws = {'size': 3, **(plot_kws if plot_kws is not None else {})}
    diag_kws = {'multiple': 'stack', **(diag_kws if diag_kws is not None else {})}
    if 'model' in results.columns:
        def key_fun(x):
            return [int("".join(val)) if key else "".join(val)
                    for key, val in groupby(x, key=lambda x: x.isdigit())]

        results.model = pd.Categorical(results.model)
        results.model = results.model.cat.reorder_categories(
            sorted(results.model.cat.categories, key=key_fun), ordered=True)
        f = sns.PairGrid(results, y_vars='loss', hue='model', **kwargs)
    else:
        f = sns.PairGrid(results, y_vars='loss', **kwargs)
    log_axes = {'epsilon', 'lr', 'weight_decay', 'alpha', 'xi', 'adam_epsilon', 'beta'}
    label_map = {'epsilon': r'$\epsilon$', 'lr': r'$\lambda$', 'weight_decay': r'$\omega$',
                 'alpha': r'$\alpha$', 'xi': r'$\xi$', 'hidden': r'$h$', 'beta': r'$\beta$',
                 'loss': 'validation error'}
    log2_axes = {}
    cat_axes = {'hidden'}

    def plot_fun(*args, **kwargs):
        ax = plt.gca()
        label = ax.get_xlabel()
        if label in cat_axes:
            sns.swarmplot(*args, **kwargs)
        else:
            kwargs = kwargs.copy()
            size = kwargs.get('s', kwargs['size'])
            kwargs['s'] = size ** 2
            del kwargs['size']
            sns.scatterplot(*args, **kwargs)
        # if ax.get_xlabel() == 'hidden':
        #     ax.set_xticks(hidden_vals)
        #     ax.set_xticklabels(hidden_vals)
        #     ax.set_xlim(0.9 * hidden_vals.min(), hidden_vals.max() / 0.9)
        #     ax.minorticks_off()

    f.map_offdiag(plot_fun, **plot_kws)
    f.map_diag(sns.histplot, **diag_kws)
    for axs in f.axes:
        for ax in axs:
            xlabel = ax.get_xlabel()
            if xlabel in log_axes:
                ax.set_xscale('log')
            if xlabel in log2_axes:
                ax.set_xscale('log', base=2)
                ax.xaxis.set_major_formatter('{x:.0f}')
            if xlabel in label_map:
                ax.set_xlabel(label_map[xlabel])
            ylabel = ax.get_ylabel()
            if ylabel in label_map:
                ax.set_ylabel(label_map[ylabel])
    return f

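# Illustrative sketch (not part of the module itself): running the hyper-parameter
# search, retraining with the selected configuration and plotting the trial results.
# The variable `problem` is a hypothetical ClassificationProblem.
#
#     best_model_args, best_train_args, results = hyper_tune(problem, max_evals=100)
#     model = MLP(input_dim=problem.num_features, output_dim=problem.num_labels,
#                 **best_model_args)
#     model = train(problem, model, **best_train_args)
#     plot_hyper_results(results)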