import torch.nn as nn
from ..layers import GCN, AvgReadout, Discriminator
# DGI model definition.
class DGI(nn.Module):
    """DGI model built from a GCN encoder, a readout and a discriminator.

    The name suggests Deep Graph Infomax (TODO confirm). The same ``GCN``
    encoder is applied to two node-feature inputs; a summary is read out
    from the first encoding, squashed with a sigmoid, and handed to the
    discriminator together with both encodings.
    """

    def __init__(self, n_in, n_h, activation='prelu'):
        """Build the sub-modules.

        Args:
            n_in: First argument forwarded to ``GCN`` (presumably the input
                feature size -- verify against ``GCN``).
            n_h: Second argument forwarded to ``GCN`` and the sole argument
                of ``Discriminator`` (presumably the hidden size).
            activation: Activation specifier forwarded to ``GCN``.
        """
        super().__init__()
        self.gcn = GCN(n_in, n_h, activation)
        self.read = AvgReadout()
        self.sigm = nn.Sigmoid()
        self.disc = Discriminator(n_h)

    def reset_parameters(self):
        """Call ``reset_parameters`` on every direct child that defines it."""
        for child in self.children():
            # Children without the method (e.g. nn.Sigmoid) are skipped.
            if hasattr(child, 'reset_parameters'):
                child.reset_parameters()

    def forward(self, seq1, seq2, adj, msk, samp_bias1, samp_bias2):
        """Score two encoded inputs against the summary of the first.

        Args:
            seq1: Input encoded to produce the summary (presumably the
                original node features -- TODO confirm).
            seq2: Second input encoded with the same encoder (presumably
                the corrupted features -- TODO confirm).
            adj: Graph structure passed to the encoder for both inputs.
            msk: Mask forwarded to the readout.
            samp_bias1: Forwarded unchanged to the discriminator.
            samp_bias2: Forwarded unchanged to the discriminator.

        Returns:
            Whatever the discriminator returns for
            ``(summary, h_1, h_2, samp_bias1, samp_bias2)``.
        """
        h_1 = self.gcn(seq1, adj)

        # The summary is derived from the first encoding only.
        summary = self.sigm(self.read(h_1, msk))

        h_2 = self.gcn(seq2, adj)

        return self.disc(summary, h_1, h_2, samp_bias1, samp_bias2)

    def embed(self, data, msk=None):
        """Return the encoder output for ``data``, detached from the graph.

        Args:
            data: Object exposing ``x`` and ``edge_index`` attributes.
            msk: Unused; kept so existing callers keep working.

        Returns:
            ``self.gcn(data.x, data.edge_index)`` with gradients detached.
        """
        # NOTE(review): forward() passes ``adj`` to the encoder while this
        # method passes ``data.edge_index`` -- confirm GCN accepts both.
        h_1 = self.gcn(data.x, data.edge_index)

        # Detach the return variables
        return h_1.detach()