import torch
import torch.nn as nn
class Discriminator(nn.Module):
    """Bilinear discriminator that scores embeddings against a summary vector.

    A single ``nn.Bilinear(n_h, n_h, 1)`` layer computes ``h^T W c + b`` for
    every row ``h`` of a "positive" and a "negative" embedding tensor. The two
    score vectors are concatenated into one logits tensor, positives first.

    Args:
        n_h: Feature size shared by the embeddings and the summary vector.
    """

    def __init__(self, n_h):
        super().__init__()
        # One bilinear form producing a single score per (embedding, summary) pair.
        self.f_k = nn.Bilinear(n_h, n_h, 1)
        self.reset_parameters()

    def reset_parameters(self):
        """Xavier-uniform initialise every Bilinear weight and zero its bias."""
        for m in self.modules():
            if isinstance(m, nn.Bilinear):
                # nn.init functions already run under torch.no_grad(), so the
                # Parameters are initialised directly rather than via ``.data``.
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(self, c, h_pl, h_mi, s_bias1=None, s_bias2=None):
        """Score positive and negative embeddings against the summary ``c``.

        Args:
            c: Summary vector. It gets a leading axis and is expanded to
                ``h_pl``'s shape, presumably ``(n_h,)`` -> ``(N, n_h)``
                (TODO confirm with callers).
            h_pl: Positive embeddings, presumably ``(N, n_h)``.
            h_mi: Negative embeddings. They are paired with the same expanded
                ``c``, so they must have the same shape as ``h_pl``.
            s_bias1: Optional additive term for the positive scores.
            s_bias2: Optional additive term for the negative scores.

        Returns:
            ``torch.cat((positive_scores, negative_scores), 0)``. For 2-D
            inputs this is a 1-D tensor of length ``2 * N``.
        """
        # Pair every row of h_pl / h_mi with the same summary vector.
        c_x = torch.unsqueeze(c, 0).expand_as(h_pl)

        # Bilinear output has a trailing size-1 score axis at dim 1; drop it.
        sc_1 = torch.squeeze(self.f_k(h_pl, c_x), 1)
        sc_2 = torch.squeeze(self.f_k(h_mi, c_x), 1)

        # Out-of-place adds: avoid mutating a view of an autograd intermediate
        # and let the bias broadcast without an in-place shape restriction.
        if s_bias1 is not None:
            sc_1 = sc_1 + s_bias1
        if s_bias2 is not None:
            sc_2 = sc_2 + s_bias2

        return torch.cat((sc_1, sc_2), 0)