# Source code for local2global_embedding.embedding.dgi.layers.readout
import torch
import torch.nn as nn
# Applies an average over seq, which has shape (batch, nodes, features),
# while taking the mask msk into account.
import torch
import torch.nn as nn
# Applies an average over seq, which has shape (batch, nodes, features),
# while taking the mask msk into account.