Source code for local2global_embedding.embedding.dgi.utils.loss

import torch_geometric as tg
import torch


[docs] class DGILoss(torch.nn.Module):
[docs] def __init__(self): super().__init__() self.loss_fun = torch.nn.BCEWithLogitsLoss()
[docs] def forward(self, model, data: tg.data.Data): device = data.edge_index.device nb_nodes = data.num_nodes idx = torch.randperm(nb_nodes, device=device) shuf_fts = data.x[idx, :] lbl_1 = torch.ones(nb_nodes, device=device) lbl_2 = torch.zeros(nb_nodes, device=device) lbl = torch.cat((lbl_1, lbl_2), 0) logits = model(data.x, shuf_fts, data.edge_index, None, None, None) return self.loss_fun(logits, lbl)