# Source code for local2global_embedding.embedding.dgi.models.logreg

import torch
import torch.nn as nn
import torch.nn.functional as F

class LogReg(nn.Module):
    """Single linear layer mapping ``ft_in`` input features to ``nb_classes`` outputs.

    No activation is applied in :meth:`forward`; the module returns the raw
    output of the linear layer.
    """

    def __init__(self, ft_in: int, nb_classes: int) -> None:
        """Build the linear layer and initialise its parameters.

        Args:
            ft_in: Size of the last dimension of the input.
            nb_classes: Number of output units.
        """
        super().__init__()
        self.fc = nn.Linear(ft_in, nb_classes)
        # Apply the custom initialisation to every submodule (here only ``fc``
        # is an ``nn.Linear``; other module types are left untouched).
        self.apply(self.weights_init)

    def weights_init(self, m: nn.Module) -> None:
        """Initialise ``m`` in place if it is an ``nn.Linear``.

        Weights get Xavier/Glorot-uniform initialisation and the bias (when
        present) is set to zero. Any other module type is ignored.

        Args:
            m: Module to (possibly) initialise.
        """
        if isinstance(m, nn.Linear):
            # nn.init functions operate under no_grad, so there is no need to
            # reach into ``.data`` (which bypasses autograd's version tracking).
            nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.zeros_(m.bias)

    def forward(self, seq: torch.Tensor) -> torch.Tensor:
        """Apply the linear layer to ``seq``.

        Args:
            seq: Input tensor whose last dimension has size ``ft_in``.

        Returns:
            Tensor with the last dimension replaced by ``nb_classes``.
        """
        return self.fc(seq)