import torch.nn as nn
import torch_geometric.nn as tg_nn
class GCN(nn.Module):
    """Single graph-convolution layer followed by an activation.

    Wraps ``tg_nn.GCNConv`` and applies ``act`` to the convolution output.

    Args:
        in_ft: Passed to ``GCNConv`` as ``in_channels``.
        out_ft: Passed to ``GCNConv`` as ``out_channels``.
        act: The string ``'prelu'`` creates a fresh ``nn.PReLU``; any other
            value is stored as-is and called on the convolution output.
            ``None`` means no activation is applied.
        bias: Passed to ``GCNConv`` as ``bias``.
    """

    def __init__(self, in_ft: int, out_ft: int, act, bias: bool = True):
        super().__init__()
        self.conv = tg_nn.GCNConv(in_channels=in_ft, out_channels=out_ft, bias=bias)
        self.act = nn.PReLU() if act == 'prelu' else act
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the convolution and, where possible, the activation."""
        self.conv.reset_parameters()
        if hasattr(self.act, 'reset_parameters'):
            self.act.reset_parameters()
        elif isinstance(self.act, nn.PReLU):
            # Fallback for a PReLU that exposes no reset_parameters():
            # restore the slope to 0.25 in place.
            nn.init.constant_(self.act.weight, 0.25)

    def forward(self, seq, adj):
        """Run the graph convolution and apply the activation.

        Args:
            seq: Node features, forwarded as the first argument of ``GCNConv``.
                NOTE(review): the original comment gave the shape as
                (batch, nodes, features) -- confirm against the callers.
            adj: Graph connectivity, forwarded as the second argument of
                ``GCNConv``. NOTE(review): presumably an edge index or sparse
                adjacency as accepted by ``GCNConv`` -- confirm.

        Returns:
            The activated convolution output, or the raw convolution output
            when ``self.act`` is ``None``.
        """
        out = self.conv(seq, adj)
        # A missing activation is treated as identity rather than failing
        # with "'NoneType' object is not callable".
        return out if self.act is None else self.act(out)