Source code for local2global_embedding.embedding.dgi.layers.readout

import torch
import torch.nn as nn


# Averages seq over its first dimension (dim 0) -- presumably the node axis of a
# (nodes, features) tensor; if a mask msk is given, rows of seq are weighted by it.
class AvgReadout(nn.Module):
    """Average-pooling readout over the first dimension (dim 0) of ``seq``.

    Without a mask this is ``torch.mean(seq, 0)``. With a mask, each slice
    ``seq[i]`` is weighted by ``msk[i]`` and the weighted sum over dim 0 is
    divided by the total mask weight, i.e. a (weighted) mean over the
    masked-in entries.
    """

    def __init__(self):
        super().__init__()

    def forward(self, seq, msk):
        """Return the (optionally masked) mean of ``seq`` over dim 0.

        Args:
            seq: Tensor whose dim 0 is reduced; presumably ``(nodes, features)``
                -- TODO confirm against the caller.
            msk: ``None`` for a plain mean, or a tensor of per-row weights
                (e.g. 0/1 or bool). It is given a trailing axis and must
                broadcast against ``seq``.

        Returns:
            ``seq`` reduced over dim 0.

        Note:
            An all-zero ``msk`` divides by zero and yields NaN.
        """
        if msk is None:
            return torch.mean(seq, 0)
        # Add a trailing axis so the per-row weights broadcast over the
        # remaining (feature) dimensions of seq.
        msk = torch.unsqueeze(msk, -1)
        # Normalise by the total mask weight, not by the number of rows.
        return torch.sum(seq * msk, 0) / torch.sum(msk)