Source code for local2global_embedding.embedding.dgi.layers.discriminator

import torch
import torch.nn as nn


class Discriminator(nn.Module):
    """Bilinear discriminator scoring node embeddings against a summary vector.

    A single ``nn.Bilinear(n_h, n_h, 1)`` layer produces one scalar score per
    (embedding, summary) pair.
    """

    def __init__(self, n_h):
        """Create the bilinear scoring layer and initialise its parameters.

        Args:
            n_h: Size of both inputs of the bilinear layer (embedding and
                summary dimension).
        """
        super().__init__()
        self.f_k = nn.Bilinear(n_h, n_h, 1)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise every ``nn.Bilinear`` submodule.

        Weights get Xavier-uniform initialisation; biases, when present, are
        set to zero.
        """
        for m in self.modules():
            if isinstance(m, nn.Bilinear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(self, c, h_pl, h_mi, s_bias1=None, s_bias2=None):
        """Score two sets of embeddings against the summary ``c``.

        Args:
            c: Summary tensor. It gets a new leading dimension and is expanded
                to the shape of ``h_pl``.
            h_pl: First set of embeddings (presumably the "positive" samples;
                confirm against the caller).
            h_mi: Second set of embeddings (presumably the "negative"
                samples). It is paired with the summary expanded to
                ``h_pl``'s shape, so it must be shape-compatible with ``h_pl``.
            s_bias1: Optional value added to the scores of ``h_pl``.
            s_bias2: Optional value added to the scores of ``h_mi``.

        Returns:
            The scores for ``h_pl`` followed by the scores for ``h_mi``,
            concatenated along dimension 0.
        """
        # NOTE(review): squeezing dim 1 only drops the bilinear layer's size-1
        # output dimension when the embeddings are 2-D, i.e. (N, n_h) with a
        # 1-D summary of shape (n_h,) -- confirm this is the calling convention.
        c_x = torch.unsqueeze(c, 0).expand_as(h_pl)

        sc_1 = torch.squeeze(self.f_k(h_pl, c_x), 1)
        sc_2 = torch.squeeze(self.f_k(h_mi, c_x), 1)

        # Add the biases out of place so the bilinear output is never mutated
        # through its squeezed view.
        if s_bias1 is not None:
            sc_1 = sc_1 + s_bias1
        if s_bias2 is not None:
            sc_2 = sc_2 + s_bias2

        return torch.cat((sc_1, sc_2), 0)