Source code for grb.model.dgl.grand

import dgl
import torch
import numpy as np
import torch.nn as nn
import dgl.function as fn

from grb.model.torch import MLP


[docs]class GRAND(nn.Module): r""" Description ----------- Graph Random Neural Networks (`GRAND <https://arxiv.org/abs/2005.11079>`__) Parameters ----------- in_features : int Dimension of input features. out_features : int Dimension of output features. hidden_features : int or list of int Dimension of hidden features. List if multi-layer. n_layers : int Number of layers. s: int Number of Augmentation samples k: int Number of Propagation Steps node_dropout: float Dropout rate on node features. input_dropout: float Dropout rate of the input layer of a MLP hidden_dropout: float Dropout rate of the hidden layer of a MLP """ def __init__(self, in_features, out_features, hidden_features, n_layers=2, s=1, k=3, temp=1.0, lam=1.0, feat_norm=None, adj_norm_func=None, node_dropout=0.0, input_dropout=0.0, hidden_dropout=0.0): super(GRAND, self).__init__() self.in_features = in_features self.out_features = out_features self.feat_norm = feat_norm self.adj_norm_func = adj_norm_func self.s = s self.k = k self.temp = temp self.lam = lam self.mlp = MLP(in_features=in_features, out_features=out_features, hidden_features=hidden_features, n_layers=n_layers, dropout=hidden_dropout) self.node_dropout = node_dropout if input_dropout > 0.0: self.input_dropout = nn.Dropout(input_dropout) else: self.input_dropout = None
[docs] def forward(self, x, adj): graph = dgl.from_scipy(adj).to(x.device) graph.ndata['features'] = x if self.input_dropout is not None: x = self.input_dropout(x) if self.training: output_list = [] for s in range(self.s): x_drop = drop_node(x, self.node_dropout, training=True) x_drop = GRANDConv(graph, x_drop, hop=self.k) output_list.append(torch.log_softmax(self.mlp(x_drop), dim=-1)) return output_list else: x = GRANDConv(graph, x, self.k) x = self.mlp(x) return x
@property def model_type(self): return "dgl" @property def model_name(self): return "grand"
[docs]def drop_node(features, drop_rate, training=True): n = features.shape[0] drop_rates = torch.FloatTensor(np.ones(n) * drop_rate) if training: masks = torch.bernoulli(1. - drop_rates).unsqueeze(1) features = masks.to(features.device) * features return features
[docs]def GRANDConv(graph, features, hop): r""" Parameters ----------- graph : dgl.Graph The input graph features : torch.Tensor Tensor of node features hop : int Propagation Steps """ with graph.local_scope(): degs = graph.in_degrees().float().clamp(min=1) norm = torch.pow(degs, -0.5).to(features.device).unsqueeze(1) graph.ndata['norm'] = norm graph.apply_edges(fn.u_mul_v('norm', 'norm', 'weight')) x = features y = 0.0 + features for i in range(hop): graph.ndata['h'] = x graph.update_all(fn.u_mul_e('h', 'weight', 'm'), fn.sum('m', 'h')) x = graph.ndata.pop('h') y.add_(x) return y / (hop + 1)