import torch_geometric as tg
import torch
class DGILoss(torch.nn.Module):
    """Deep Graph Infomax (DGI) style contrastive loss.

    A corrupted copy of the input graph is built by shuffling the rows of the
    node-feature matrix while keeping ``edge_index`` unchanged. The model
    scores the nodes of both the clean and the corrupted graph. Those raw
    scores are trained with binary cross-entropy towards 1 for clean nodes and
    0 for corrupted nodes.
    """

    def __init__(self):
        super().__init__()
        # Operates on raw logits; default reduction is the mean over all scores.
        self.loss_fun = torch.nn.BCEWithLogitsLoss()

    def forward(self, model, data: "tg.data.Data"):
        """Compute the DGI loss of ``model`` on the graph ``data``.

        Args:
            model: Callable invoked as
                ``model(x, shuffled_x, edge_index, None, None, None)``. It must
                return raw logits (no sigmoid) with the ``num_nodes`` scores of
                the clean graph followed by the ``num_nodes`` scores of the
                corrupted graph, i.e. ``2 * num_nodes`` values in total. The
                three trailing ``None`` arguments are model-specific
                (NOTE(review): presumably mask / sampling-bias slots — confirm
                against the model's ``forward``).
            data: Graph object providing ``x`` (node features, one row per
                node), ``edge_index`` and ``num_nodes``.

        Returns:
            Scalar tensor with the mean binary cross-entropy over all
            ``2 * num_nodes`` scores.
        """
        num_nodes = data.num_nodes

        # Corruption: permute feature rows, keep the graph structure. The
        # index tensor must live on the same device as the tensor it indexes.
        perm = torch.randperm(num_nodes, device=data.x.device)
        shuffled_x = data.x[perm]

        logits = model(data.x, shuffled_x, data.edge_index, None, None, None)

        # Targets: 1 for the clean half, 0 for the corrupted half. They are
        # created with the logits' dtype/device so the loss works for
        # float64 / half-precision models and wherever the model placed them.
        labels = torch.cat(
            (
                torch.ones(num_nodes, dtype=logits.dtype, device=logits.device),
                torch.zeros(num_nodes, dtype=logits.dtype, device=logits.device),
            )
        )
        # BCEWithLogitsLoss requires identical shapes. Accept models that
        # return the same 2 * num_nodes scores with extra singleton dims,
        # e.g. (1, 2 * num_nodes) with a leading batch dimension.
        if logits.shape != labels.shape and logits.numel() == labels.numel():
            labels = labels.view_as(logits)
        return self.loss_fun(logits, labels)