Source code for local2global_embedding.embedding.dgi.models.logreg
import torch
import torch.nn as nn
import torch.nn.functional as F
class LogReg(nn.Module):
    """Single-layer logistic-regression classifier applied on top of node embeddings."""

    def __init__(self, ft_in, nb_classes):
        super(LogReg, self).__init__()
        # One linear layer mapping ft_in-dimensional features to nb_classes logits.
        self.fc = nn.Linear(ft_in, nb_classes)
        for m in self.modules():
            self.weights_init(m)

    def weights_init(self, m):
        # Xavier-uniform initialisation for linear weights; biases set to zero.
        if isinstance(m, nn.Linear):
            torch.nn.init.xavier_uniform_(m.weight.data)
            if m.bias is not None:
                m.bias.data.fill_(0.0)
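A minimal usage sketch follows. The feature dimension (128), number of classes (7), and the batch of random embeddings are hypothetical; since no ``forward`` method appears in this listing, the linear layer ``fc`` is applied directly.

    import torch

    from local2global_embedding.embedding.dgi.models.logreg import LogReg

    # Hypothetical sizes: 128-dimensional embeddings, 7 target classes.
    embeddings = torch.randn(32, 128)
    clf = LogReg(ft_in=128, nb_classes=7)

    # weights_init above gives Xavier-initialised weights and zero biases.
    print(clf.fc.weight.shape)      # torch.Size([7, 128])
    print(clf.fc.bias.abs().sum())  # tensor(0.)

    # Class logits for each embedding, shape (32, 7).
    logits = clf.fc(embeddings)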