Source code for local2global_embedding.embedding.dgi.layers.gcn

import torch.nn as nn
import torch_geometric.nn as tg_nn


class GCN(nn.Module):
    """Single graph-convolution layer followed by an activation.

    The layer wraps ``tg_nn.GCNConv`` and applies ``act`` to its output.
    """

    def __init__(self, in_ft: int, out_ft: int, act, bias: bool = True):
        """Build the convolution and resolve the activation.

        Args:
            in_ft: Forwarded to ``GCNConv`` as ``in_channels``.
            out_ft: Forwarded to ``GCNConv`` as ``out_channels``.
            act: Either the string ``'prelu'`` (a fresh ``nn.PReLU`` is
                created) or an object that is stored as-is and later called
                on the convolution output.
            bias: Forwarded to ``GCNConv`` as ``bias``.
        """
        super().__init__()
        self.conv = tg_nn.GCNConv(
            in_channels=in_ft,
            out_channels=out_ft,
            bias=bias,
        )
        if act == 'prelu':
            self.act = nn.PReLU()
        else:
            self.act = act
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the convolution and, where possible, the activation."""
        self.conv.reset_parameters()
        if hasattr(self.act, 'reset_parameters'):
            # Prefer the activation's own reset hook when it exposes one.
            self.act.reset_parameters()
        elif isinstance(self.act, nn.PReLU):
            # Fallback for a PReLU without a reset hook: set the slope to 0.25.
            self.act.weight.data.fill_(0.25)

    def forward(self, seq, adj):
        """Apply the convolution to ``(seq, adj)`` and then the activation.

        Args:
            seq: First argument handed to ``self.conv``.
                NOTE(review): the original comment gives its shape as
                (batch, nodes, features) - confirm against the callers.
            adj: Second argument handed to ``self.conv``; presumably the
                graph connectivity expected by ``GCNConv`` - verify.

        Returns:
            ``self.act`` applied to the convolution output.
        """
        out = self.conv(seq, adj)
        return self.act(out)