import torch
import torch.utils.data
from hyperopt import fmin, hp, tpe, Trials, space_eval, rand
from hyperopt.pyll import scope
import pandas as pd
import numpy as np
from math import log
from copy import deepcopy
from itertools import count, chain, product, groupby
from tqdm.auto import tqdm
from collections.abc import Sequence
from local2global_embedding.utils import EarlyStopping, get_device
class Logistic(torch.nn.Module):
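    """Multinomial logistic regression classifier: a single linear layer followed by log-softmax."""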
    def __init__(self, input_dim, output_dim, bias=True):
        super().__init__()
        self.linear = torch.nn.Linear(input_dim, output_dim, bias)
        self.softmax = torch.nn.LogSoftmax(dim=-1)
        self.reset_parameters() 
    def reset_parameters(self):
        torch.nn.init.xavier_uniform_(self.linear.weight.data)
        if self.linear.bias is not None:
            self.linear.bias.data.fill_(0.0) 
    def forward(self, x):
        return self.softmax(self.linear(x)) 
 
def _mlp_hidden_layer(in_dim, out_dim, batch_norm=True, dropout=0, relu_last=False):
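    """Build one hidden block as a tuple of modules: Linear, ReLU and optional BatchNorm1d
    (applied before the ReLU if ``relu_last`` is set), followed by optional Dropout."""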
    lin = torch.nn.Linear(in_dim, out_dim, bias=True)
    nl = torch.nn.ReLU()
    if batch_norm:
        bn = torch.nn.BatchNorm1d(out_dim)
        if relu_last:
            layer_list = (lin, bn, nl)
        else:
            layer_list = (lin, nl, bn)
    else:
        layer_list = (lin, nl)
    if dropout > 0:
        return *layer_list, torch.nn.Dropout(dropout)
    else:
        return layer_list
class MLP(torch.nn.Module):
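    """Multi-layer perceptron classifier with ReLU hidden layers, optional batch norm and dropout,
    and a final log-softmax output."""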
    def __init__(self, input_dim, hidden_dim, output_dim, n_layers=2, batch_norm=False, dropout=0, relu_last=False):
        super().__init__()
        self.network = torch.nn.Sequential(*_mlp_hidden_layer(input_dim, hidden_dim, batch_norm, dropout, relu_last),
                                           *chain.from_iterable(_mlp_hidden_layer(hidden_dim, hidden_dim, batch_norm,
                                                                                  dropout, relu_last)
                                                                for _ in range(n_layers-2)),
                                           torch.nn.Linear(hidden_dim, output_dim, bias=True),
                                           torch.nn.LogSoftmax(dim=-1))
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        self.n_layers = n_layers
        self.reset_parameters() 
    def forward(self, x):
        return self.network(x) 
    def reset_parameters(self):
        for layer in self.network:
            if isinstance(layer, torch.nn.Linear):
                torch.nn.init.kaiming_normal_(layer.weight, nonlinearity='relu') 
    def __repr__(self):
        return f"{self.__class__.__name__}({self.hidden_dim}, {self.n_layers})" 
class SNN(torch.nn.Module):
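    """Self-normalising network: bias-free linear layers with SELU activations, returning raw logits."""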
    def __init__(self, input_dim, hidden_dim, output_dim, n_layers=2):
        super().__init__()
        self.network = torch.nn.Sequential(torch.nn.Linear(input_dim, hidden_dim, bias=False), torch.nn.SELU(),
                                           *chain.from_iterable((torch.nn.Linear(hidden_dim, hidden_dim, bias=False),
                                                                 torch.nn.SELU()) for _ in range(n_layers - 2)),
                                           torch.nn.Linear(hidden_dim, output_dim, bias=False))
        self.input_dim = input_dim
        self.hidden_dim = hidden_dim
        self.output_dim = output_dim
        self.n_layers = n_layers
        self.reset_parameters() 
    def forward(self, x):
        return self.network(x) 
    def reset_parameters(self):
        for layer in self.network:
            if isinstance(layer, torch.nn.Linear):
                torch.nn.init.kaiming_normal_(layer.weight, nonlinearity='linear') 
 
def random_split(y, num_train_per_class=20, num_val=500):
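    """Randomly split nodes into train/val/test indices.

    Selects ``num_train_per_class`` nodes per class for training, ``num_val`` of the remaining
    labelled nodes for validation, and uses the rest for testing.
    """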
    split = {}
    num_classes = int(y.max().item()) + 1
    train_mask = torch.zeros(y.size(), dtype=torch.bool)
    for c in range(num_classes):
        idx = (y == c).nonzero(as_tuple=False).view(-1)
        idx = idx[torch.randperm(idx.size(0))]
        idx = idx[:num_train_per_class]
        train_mask[idx] = True
    split['train'] = train_mask.nonzero(as_tuple=False).view(-1)
    remaining = (~train_mask).nonzero(as_tuple=False).view(-1)
    remaining = remaining[y[remaining] >= 0]  # only consider labelled data (unlabelled nodes have label -1)
    remaining = remaining[torch.randperm(remaining.size(0))]
    split['val'] = remaining[:num_val]
    split['test'] = remaining[num_val:]
    return split 
class ClassificationProblem:
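    """Node-classification problem bundling labels ``y``, features ``x`` and a train/val/test split,
    with helpers exposing the different subsets as torch datasets."""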
    def __init__(self, y, x=None, split=None):
        self.y = y
        self.x = x
        self.num_labels = int(y.max()) + 1
        if split is None:
            self.resplit()
        else:
            self.split = split
        self._val_data = None
        self._test_data = None
        self._train_data = None 
    def resplit(self, num_train_per_class=20, num_val=500):
        self.split = random_split(self.y, num_train_per_class=num_train_per_class, num_val=num_val) 
    @property
    def split(self):
        return {'train': self.train_index, 'val': self.val_index, 'test': self.test_index}
    @property
    def num_features(self):
        if self.x is None:
            return None
        else:
            return self.x.shape[1]
    @split.setter
    def split(self, split):
        self.train_index = split['train']
        self.val_index = split['val']
        self.test_index = split['test']
    def training_data(self, include_unlabeled=False):
        if self.x is None:
            raise RuntimeError('Need to set embedding first')
        if include_unlabeled:
            y = torch.as_tensor(self.y).clone()  # copy so masking does not modify the stored labels
            y[self.val_index] = -1
            y[self.test_index] = -1
            if isinstance(self.x, np.memmap):
                return MMapData(self.x, y)
            else:
                x = torch.as_tensor(self.x)
                return torch.utils.data.TensorDataset(x, y)
        else:
            if self._train_data is None:
                self._train_data = torch.utils.data.TensorDataset(torch.as_tensor(self.x[self.train_index, :]),
                                                  torch.as_tensor(self.y[self.train_index]))
            return self._train_data 
    def validation_data(self):
        if self._val_data is None:
            self._val_data = torch.utils.data.TensorDataset(torch.as_tensor(self.x[self.val_index, :]),
                                              torch.as_tensor(self.y[self.val_index]))
        return self._val_data 
    def test_data(self):
        if self._test_data is None:
            self._test_data = torch.utils.data.TensorDataset(torch.as_tensor(self.x[self.test_index, :]),
                                              torch.as_tensor(self.y[self.test_index]))
        return self._test_data 
    def labeled_data(self):
        return torch.utils.data.TensorDataset(torch.as_tensor(self.x[self.y >= 0, :]),
                                              torch.as_tensor(self.y[self.y >= 0])) 
    def all_data(self):
        if isinstance(self.x, np.memmap):
            return MMapData(self.x, self.y)
        else:
            return torch.utils.data.TensorDataset(torch.as_tensor(self.x), torch.as_tensor(self.y)) 
 
class MMapData(torch.utils.data.Dataset):
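    """Dataset view over memory-mapped features that converts rows to tensors only on access."""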
    def __init__(self, x, y):
        self.x = x
        self.y = y 
    def __len__(self):
        return len(self.y)
    def __getitem__(self, item):
        return torch.as_tensor(self.x[item]), torch.as_tensor(self.y[item]) 
class logger:
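    """Callable batch logger that records the training loss together with the current
    validation and test accuracy of ``model`` on ``data``."""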
    def __init__(self, data, model):
        self.loss = []
        self.val_loss = []
        self.test_loss = []
        self.data = data
        self.model = model 
    def __call__(self, loss):
        self.loss.append(loss)
        self.val_loss.append(accuracy(self.data, self.model, mode='val'))
        self.test_loss.append(accuracy(self.data, self.model, mode='test'))
class VATloss(torch.nn.Module):
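    """Virtual adversarial training (VAT) loss.

    Finds an adversarial perturbation of norm ``epsilon`` via ``it`` power-iteration steps
    (each starting from a direction of norm ``xi``) and returns the KL divergence between
    the model's predictions on the clean and the perturbed inputs.
    """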
    def __init__(self, epsilon, xi=1e-6, it=1):
        super().__init__()
        self.epsilon = epsilon
        self.xi = xi
        self.it = it
        self.divergence = torch.nn.KLDivLoss(reduction='batchmean', log_target=True) 
    def forward(self, model: torch.nn.Module, x, p=None):
        if p is None:
            with torch.no_grad():
                p = model(x)
        with torch.no_grad():
            r = torch.randn_like(x)
            r /= torch.norm(r, p=2, dim=-1, keepdim=True)
            p = p.detach()
        for _ in range(self.it):
            with torch.no_grad():
                r = (self.xi * r).clone().detach()
            r.requires_grad = True
            model.zero_grad()
            d = self.divergence(model(x + r), p)
            d.backward()
            with torch.no_grad():
                r.grad += 1e-16
                r = (r.grad / torch.norm(r.grad, p=2, dim=-1, keepdim=True))
        with torch.no_grad():
            r = (self.epsilon * r).detach()
        model.zero_grad()
        div = self.divergence(model(x + r), p)
        # div = torch.sum(p * self.logsoftmax(model(x + r_adv)))
        return div 
 
class EntMin(torch.nn.Module):
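    """Entropy minimisation regulariser: mean entropy of the predicted class distribution."""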
    def forward(self, logits):
        return torch.mean(torch.distributions.Categorical(logits=logits).entropy(), dim=0) 
 
class BatchedData(torch.utils.data.Dataset):
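    """Expose a dataset as consecutive slices of size ``batch_size`` so that a DataLoader with
    ``batch_size=1`` yields whole mini-batches per item, avoiding per-sample indexing overhead."""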
    def __init__(self, data: torch.utils.data.TensorDataset, batch_size):
        self.data = data
        self.batch_size = batch_size 
    def __getitem__(self, item):
        index = item*self.batch_size
        return self.data[index:index+self.batch_size]
    def __len__(self):
        return len(range(0, len(self.data), self.batch_size)) 
def train(data: ClassificationProblem, model: torch.nn.Module, num_epochs, batch_size, lr=0.01, batch_logger=lambda loss: None,
          epoch_logger=lambda epoch: None, device=None, epsilon=1, alpha=0, beta=0, weight_decay=1e-2, decay_lr=False, xi=1e-6,
          vat_it=1,
          teacher_alpha=0, beta_1=0.9, beta_2=0.999, adam_epsilon=1e-8, patience=None):
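    """Train ``model`` on ``data`` with Adam and optional semi-supervised regularisation.

    ``alpha`` weights the VAT loss and ``beta`` the entropy-minimisation loss (unlabelled points are
    included in training whenever either is positive). ``teacher_alpha`` maintains an
    exponential-moving-average copy of the model (mean-teacher style), ``decay_lr`` enables cosine
    learning-rate decay and ``patience`` controls early stopping on the validation loss.
    """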
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if alpha > 0 or beta > 0:
        train_data = data.training_data(include_unlabeled=True)
    else:
        train_data = data.training_data(include_unlabeled=False)
    model = model.to(device)
    if teacher_alpha:
        teacher = deepcopy(model)
        for param in teacher.parameters():
            param.detach_()
        it_count = count()
        def update_teacher():
            it = next(it_count)
            alpha = min(1 - 1 / (it + 1), teacher_alpha)
            for teacher_param, model_param in zip(teacher.parameters(), model.parameters()):
                teacher_param.data.mul_(alpha).add_(model_param, alpha=1 - alpha)
    else:
        teacher = None
        def update_teacher():
            pass
    # optimizer = torch.optim.Adamax(model.parameters(), lr=lr, weight_decay=weight_decay)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay, betas=(beta_1, beta_2),
                                 eps=adam_epsilon)
    if batch_size >= len(train_data) and isinstance(train_data, torch.utils.data.TensorDataset):
        # whole dataset fits in a single batch: move it to the device once up front
        train_data = torch.utils.data.TensorDataset(*(t.to(device) for t in train_data.tensors))
    # optimizer = torch.optim.SGD(model.parameters(), lr=lr, weight_decay=weight_decay)
    data_loader = torch.utils.data.DataLoader(BatchedData(train_data, batch_size=batch_size), batch_size=1,
                                              shuffle=True, collate_fn=lambda b: b[0], pin_memory=batch_size < len(train_data))
    if decay_lr:
        lr_sched = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, len(data_loader) * num_epochs)
        def step_lr():
            lr_sched.step()
    else:
        def step_lr():
            pass
    if patience is None:
        patience = float('inf')
    criterion = torch.nn.NLLLoss(reduction='mean', ignore_index=-1)
    vat_loss = VATloss(epsilon=epsilon, xi=xi, it=vat_it)
    ent_loss = EntMin()
    if alpha == 0:
        if beta == 0:
            def loss_fun(model, x, y):
                return criterion(model(x), y)
        else:
            def loss_fun(model, x, y):
                p = model(x)
                return criterion(p, y) + beta*ent_loss(p)
    else:
        if beta == 0:
            def loss_fun(model, x, y):
                p = model(x)
                return criterion(p, y) + alpha*vat_loss(model, x, p)
        else:
            def loss_fun(model, x, y):
                p = model(x)
                return criterion(p, y) + alpha*vat_loss(model, x, p) + beta*ent_loss(p)
    x_val, y_val = data.validation_data()[:]
    x_val = x_val.to(device=device, dtype=torch.float32)
    y_val = y_val.to(device=device)
    with EarlyStopping(patience) as stop:
        with tqdm(total=num_epochs, desc='training epoch') as progress:
            for e in range(num_epochs):
                model.train()
                for x, y in data_loader:
                    x = x.to(device=device, dtype=torch.float32, non_blocking=True).view(-1, x.size(-1))
                    y = y.to(device=device, non_blocking=True).view(-1)
                    optimizer.zero_grad()
                    loss = loss_fun(model, x, y)
                    loss.backward()
                    # torch.nn.utils.clip_grad_norm_(model.parameters(), clip_norm)
                    optimizer.step()
                    step_lr()
                    update_teacher()
                    batch_logger(float(loss))
                epoch_logger(e)
                model.eval()
                with torch.no_grad():
                    vl = criterion(model(x_val), y_val)
                progress.update()
                # progress.write(f'validation loss: {vl}')
                if stop(vl, model):
                    print(f'early stopping at epoch {e}')
                    break
    return model 
def predict(x, model: torch.nn.Module):
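    """Return argmax class predictions for ``x``, restoring the model's previous train/eval mode afterwards."""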
    was_training = model.training
    model.eval()
    x = x.to(device=get_device(model), dtype=torch.float32)
    with torch.no_grad():
        labels = torch.argmax(model(x), dim=-1)
    model.train(was_training)  # restore the previous train/eval mode
    return labels 
def accuracy(data: ClassificationProblem, model: torch.nn.Module, mode='test', batch_size=None):
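    """Classification accuracy of ``model`` on the ``'test'``, ``'val'`` or ``'all'`` (all labelled) subset of ``data``."""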
    if mode == 'test':
        data = data.test_data()
    elif mode == 'val':
        data = data.validation_data()
    elif mode == 'all':
        data = data.labeled_data()
    else:
        raise ValueError(f'unknown mode {mode}')
    if batch_size is None:
        batch_size = len(data)
    loader = torch.utils.data.DataLoader(data, batch_size)
    val = 0
    with torch.no_grad():
        for x, y in loader:
            x = x.to(device=get_device(model), dtype=torch.float32)
            y = y.to(get_device(model))
            val += torch.sum(predict(x, model) == y).cpu().item()
    return val / len(data) 
def validation_accuracy(data, model: torch.nn.Module, batch_size=None):
    return accuracy(data, model, mode='val', batch_size=batch_size) 
class HyperTuneObjective:
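    """Hyperopt objective: trains a model ``n_tries`` times with the sampled arguments and returns the
    mean validation error, remembering the best parameters and model seen so far."""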
    def __init__(self, data, n_tries=1, **kwargs):
        self.data = data
        self.args = kwargs
        self.min_loss = float('inf')
        self.best_parameters = deepcopy(self.args)
        self.best_model = None
        self.n_tries = n_tries
    def __call__(self, args):
        cum_loss = 0
        for _ in range(self.n_tries):
            model = train(self.data, **args, **self.args)
            loss = 1 - validation_accuracy(self.data, model)
            cum_loss += loss
            if loss < self.min_loss:
                self.min_loss = loss
                self.best_parameters.update(deepcopy(args))
                self.best_model = deepcopy(model)  # keep the trained weights before they are reset
            model.reset_parameters()
        return cum_loss / self.n_tries 
@scope.define
def mlp_model(*args, **kwargs):
    return MLP(*args, **kwargs) 
@scope.define
def linear_model(in_dim, out_dim):
    return torch.nn.Linear(in_dim, out_dim) 
@scope.define
def snn_model(in_dim, hidden_dim, out_dim, n_layers):
    return SNN(in_dim, hidden_dim, out_dim, n_layers) 
def grid_search(data, param_grid, epochs=10, batch_size=100, param_transform=lambda args: args, **kwargs):
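    """Exhaustively evaluate every combination in ``param_grid`` and return the best model,
    the best parameters and a DataFrame with one row per evaluated configuration."""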
    objective = HyperTuneObjective(data, **kwargs)
    results = []
    total = 1
    for v in param_grid.values():
        total *= len(v)
    for params in tqdm(product(*param_grid.values()), total=total, desc='grid search'):
        args = dict(zip(param_grid.keys(), params))
        args['model'].reset_parameters()
        args['loss'] = objective(args)
        args = param_transform(args)
        results.append(args)
    return objective.best_model, objective.best_parameters, pd.DataFrame.from_records(results) 
def _make_space(args):
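    """Translate ``{name: value-or-range}`` arguments into a hyperopt search space (log-uniform,
    integer, uniform or categorical depending on the parameter name); scalar values are kept fixed."""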
    log_vars = {'epsilon', 'weight_decay', 'xi', 'alpha', 'beta', 'lr'}
    int_vars = {'n_layers'}
    uniform_vars = {'dropout'}
    choice_vars = {'hidden', 'batch_norm'}
    space = {}
    for key, val in args.items():
        if isinstance(val, Sequence):
            if key in log_vars:
                space[key] = hp.loguniform(key, log(min(val)), log(max(val)))
            if key in int_vars:
                space[key] = hp.uniformint(key, min(val), max(val))
            if key in uniform_vars:
                space[key] = hp.uniform(key, min(val), max(val))
            if key in choice_vars:
                space[key] = hp.choice(key, val)
        else:
            space[key] = val
    return space
def hyper_tune(data: ClassificationProblem, max_evals=100, n_tries=1, random_search=False,
               model_args=None, train_args=None):
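    """Tune MLP architecture and training hyper-parameters with hyperopt.

    Uses random search if ``random_search`` is set and TPE otherwise; ``model_args`` and ``train_args``
    override the default search ranges. Returns the best model arguments, the best training arguments
    and a DataFrame of all evaluated configurations with their validation error.
    """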
    _model_args = {'hidden_dim': (128, 256, 512, 1024), 'n_layers': (2, 4), 'dropout': (0, 1), 'batch_norm': (False, True)}
    _train_args = {'batch_size': 100000, 'num_epochs': 1000, 'patience': 20, 'lr': (1e-4, 1e-1)}
    if model_args is not None:
        _model_args.update(model_args)
    if train_args is not None:
        _train_args.update(train_args)
    model_space = _make_space(_model_args)
    train_space = _make_space(_train_args)
    objective = HyperTuneObjective(data, n_tries=n_tries)
    trials = Trials()
    in_dim = data.num_features
    out_dim = data.num_labels
    search_space = {
        **train_space,
        'model': scope.mlp_model(input_dim=in_dim,
                                 output_dim=out_dim,
                                 **model_space,
                                 )
    }
    def transform(space, value):
        value = space_eval(space, {key: val[0] for key, val in value.items() if val})
        if isinstance(value, torch.nn.Module):
            value = repr(value)
        return value
    if random_search:
        args = fmin(fn=objective, space=search_space, algo=rand.suggest, max_evals=max_evals, trials=trials)
    else:
        args = fmin(fn=objective, space=search_space, algo=tpe.suggest, max_evals=max_evals, trials=trials)
    params = {**train_space, **model_space}
    results = pd.DataFrame.from_records({key: transform(params[key], t['misc']['vals'])
                                         for key in t['misc']['vals']}
                                        for t in trials.trials)
    results['loss'] = trials.losses()
    best_args = space_eval(params, args)
    best_model_args = {key: best_args[key] for key in _model_args}
    best_train_args = {key: best_args[key] for key in _train_args}
    return best_model_args, best_train_args, results 
def plot_hyper_results(results, plot_kws=None, diag_kws=None, **kwargs):
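    """Plot hyper-parameter values against validation error from ``hyper_tune`` results as a seaborn PairGrid."""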
    import seaborn as sns
    import matplotlib.pyplot as plt
    plot_kws = {'size': 3, **(plot_kws if plot_kws is not None else {})}
    diag_kws = {'multiple': 'stack', **(diag_kws if diag_kws is not None else {})}
    if 'model' in results.columns:
        def key_fun(x):
            return [int("".join(val)) if key else "".join(val) for key, val in groupby(x, key=lambda x: x.isdigit())]
        results.model = pd.Categorical(results.model)
        results.model = results.model.cat.reorder_categories(sorted(results.model.cat.categories, key=key_fun),
                                                             ordered=True)
        f = sns.PairGrid(results, y_vars='loss', hue='model', **kwargs)
    else:
        f = sns.PairGrid(results, y_vars='loss', **kwargs)
    log_axes = {'epsilon', 'lr', 'weight_decay', 'alpha', 'xi', 'adam_epsilon', 'beta'}
    label_map = {'epsilon': r'$\epsilon$', 'lr': r'$\lambda$', 'weight_decay': r'$\omega$', 'alpha': r'$\alpha$', 'xi': r'$\xi$',
                 'hidden': r'$h$', 'beta': r'$\beta$', 'loss': 'validation error'}
    log2_axes = {}
    cat_axes = {'hidden'}
    def plot_fun(*args, **kwargs):
        ax = plt.gca()
        label = ax.get_xlabel()
        if label in cat_axes:
            sns.swarmplot(*args, **kwargs)
        else:
            kwargs = kwargs.copy()
            # scatterplot expects a marker area 's' rather than swarmplot's 'size'
            size = kwargs.pop('size', None)
            if size is not None and 's' not in kwargs:
                kwargs['s'] = size ** 2
            sns.scatterplot(*args, **kwargs)
        # if ax.get_xlabel() == 'hidden':
        #     ax.set_xticks(hidden_vals)
        #     ax.set_xticklabels(hidden_vals)
        #     ax.set_xlim(0.9 * hidden_vals.min(), hidden_vals.max() / 0.9)
        #     ax.minorticks_off()
    f.map_offdiag(plot_fun, **plot_kws)
    f.map_diag(sns.histplot, **diag_kws)
    for axs in f.axes:
        for ax in axs:
            xlabel = ax.get_xlabel()
            if xlabel in log_axes:
                ax.set_xscale('log')
            if xlabel in log2_axes:
                ax.set_xscale('log', base=2)
                ax.xaxis.set_major_formatter('{x:.0f}')
            if xlabel in label_map:
                ax.set_xlabel(label_map[xlabel])
            ylabel = ax.get_ylabel()
            if ylabel in label_map:
                ax.set_ylabel(label_map[ylabel])
    return f
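
# Example usage (a minimal sketch; assumes ``emb`` is an (n, d) array of node embeddings and ``y`` an
# (n,) tensor of integer class labels, with -1 marking unlabelled nodes):
#
#     problem = ClassificationProblem(y, x=emb)
#     model_args, train_args, results = hyper_tune(problem, max_evals=50)
#     model = MLP(input_dim=problem.num_features, output_dim=problem.num_labels, **model_args)
#     model = train(problem, model, **train_args)
#     print(accuracy(problem, model, mode='test'))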