Source code for local2global_embedding.embedding.dgi.models.dgi

import torch
import torch.nn as nn

from ..layers import GCN, AvgReadout, Discriminator


class DGI(nn.Module):
    """Deep Graph Infomax (DGI)-style model.

    The model combines a ``GCN`` encoder, an ``AvgReadout`` summary, a sigmoid
    applied to that summary, and a ``Discriminator``.

    Args:
        n_in: Input feature dimension, forwarded to ``GCN``.
        n_h: Hidden/embedding dimension, used by both ``GCN`` and
            ``Discriminator``.
        activation: Activation spec forwarded to ``GCN`` (default ``'prelu'``).
    """

    def __init__(self, n_in, n_h, activation='prelu'):
        super().__init__()
        self.gcn = GCN(n_in, n_h, activation)
        self.read = AvgReadout()
        self.sigm = nn.Sigmoid()
        self.disc = Discriminator(n_h)

    def reset_parameters(self):
        """Re-initialise every direct child module that defines ``reset_parameters``.

        Children without such a method (e.g. ``nn.Sigmoid``) are skipped.
        """
        for m in self.children():
            if hasattr(m, 'reset_parameters'):
                m.reset_parameters()

    def forward(self, seq1, seq2, adj, msk, samp_bias1, samp_bias2):
        """Encode two feature sets over the same graph and score them.

        Args:
            seq1: Node features to encode and summarise.
            seq2: Second node-feature set encoded with the same ``adj``.
                Presumably corrupted/shuffled features serving as negatives;
                confirm against the caller.
            adj: Graph structure passed to ``GCN``.
            msk: Mask passed to the readout.
            samp_bias1, samp_bias2: Passed through unchanged to the
                discriminator.

        Returns:
            Whatever ``Discriminator`` returns for
            ``(summary, h_1, h_2, samp_bias1, samp_bias2)``.
        """
        h_1 = self.gcn(seq1, adj)
        # Summary vector: readout of the seq1 encodings, squashed by a sigmoid.
        c = self.sigm(self.read(h_1, msk))
        h_2 = self.gcn(seq2, adj)
        return self.disc(c, h_1, h_2, samp_bias1, samp_bias2)

    def embed(self, data, msk=None):
        """Return encoder output for ``data`` as a tensor detached from autograd.

        Args:
            data: Object exposing ``.x`` (node features) and ``.edge_index``
                (graph structure); both are passed directly to ``GCN``.
            msk: Unused; accepted for interface compatibility.

        Returns:
            ``self.gcn(data.x, data.edge_index)`` with no autograd history.
        """
        # no_grad avoids building an autograd graph that would only be
        # discarded by detach(); detach() is kept as a belt-and-braces guard.
        with torch.no_grad():
            h_1 = self.gcn(data.x, data.edge_index)
        return h_1.detach()