Source code for grb.model.dgl.gin

import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn.pytorch.conv import GINConv
from dgl.nn.pytorch.glob import SumPooling, AvgPooling, MaxPooling


class ApplyNodeFunc(nn.Module):
    """Node-update function: run an MLP, then batch-normalise, then ReLU.

    Parameters
    ----------
    mlp : nn.Module
        Module applied first to the node features. It must expose an
        ``output_dim`` attribute, which sizes the batch-norm layer.
    """

    def __init__(self, mlp):
        super().__init__()
        self.mlp = mlp
        # The normalisation width follows whatever the wrapped MLP emits.
        self.bn = nn.BatchNorm1d(mlp.output_dim)

    def forward(self, h):
        """Return ``relu(bn(mlp(h)))`` for the node features ``h``."""
        transformed = self.mlp(h)
        normalised = self.bn(transformed)
        return F.relu(normalised)
class MLP(nn.Module):
    """Multi-layer perceptron whose final layer is linear (no activation)."""

    def __init__(self, num_layers, input_dim, hidden_dim, output_dim):
        """Build the linear (and, for depth > 1, batch-norm) layers.

        Parameters
        ----------
        num_layers : int
            Number of linear layers; ``1`` yields a plain linear model.
        input_dim : int
            Dimensionality of the input features.
        hidden_dim : int
            Dimensionality of the hidden units, shared by all hidden layers.
        output_dim : int
            Dimensionality of the output.

        Raises
        ------
        ValueError
            If ``num_layers`` is smaller than 1.
        """
        super().__init__()
        self.linear_or_not = True  # flipped below for the multi-layer case
        self.num_layers = num_layers
        self.output_dim = output_dim

        if num_layers < 1:
            raise ValueError("number of layers should be positive!")

        if num_layers == 1:
            # Degenerate case: a single affine map.
            self.linear = nn.Linear(input_dim, output_dim)
            return

        self.linear_or_not = False
        # Consecutive pairs of this chain give each layer's (in, out) sizes:
        # input -> hidden -> ... -> hidden -> output.
        widths = [input_dim] + [hidden_dim] * (num_layers - 1) + [output_dim]
        self.linears = torch.nn.ModuleList(
            [nn.Linear(fan_in, fan_out)
             for fan_in, fan_out in zip(widths[:-1], widths[1:])]
        )
        # One batch norm after every linear layer except the last.
        self.batch_norms = torch.nn.ModuleList(
            [nn.BatchNorm1d(hidden_dim) for _ in range(num_layers - 1)]
        )

    def forward(self, x):
        """Apply the network to ``x`` and return the un-activated output."""
        if self.linear_or_not:
            return self.linear(x)

        hidden = x
        # zip stops at the shorter list (batch_norms), so the last linear
        # layer is left for the final, activation-free projection.
        for linear, batch_norm in zip(self.linears, self.batch_norms):
            hidden = F.relu(batch_norm(linear(hidden)))
        return self.linears[-1](hidden)
class GIN(nn.Module):
    """Graph Isomorphism Network (GIN) built on DGL's ``GINConv``.

    One ``GINConv`` (wrapping ``ApplyNodeFunc(MLP(...))``) and one
    ``BatchNorm1d`` are created per entry of ``hidden_features``.
    ``forward`` runs all but the last of these pairs and then applies two
    linear layers mapping
    ``hidden_features[-2] -> hidden_features[-1] -> out_features``.

    Parameters
    ----------
    in_features : int
        Dimensionality of the input node features.
    hidden_features : sequence of int
        Hidden sizes; must contain at least two entries because the
        read-out layer is sized from ``hidden_features[-2]``.
    out_features : int
        Dimensionality of the model output.
    learn_eps : bool, optional
        Forwarded to ``GINConv`` as its ``learn_eps`` argument.
    neighbor_pooling_type : str, optional
        Forwarded to ``GINConv`` as its aggregator type.
    num_mlp_layers : int, optional
        Number of linear layers inside each per-layer ``MLP``.

    Raises
    ------
    ValueError
        If ``hidden_features`` has fewer than two entries.
    """

    def __init__(self, in_features, hidden_features, out_features,
                 learn_eps=True, neighbor_pooling_type='sum',
                 num_mlp_layers=1):
        super(GIN, self).__init__()
        if len(hidden_features) < 2:
            # Fail early with a clear message instead of the IndexError that
            # hidden_features[-2] would raise further down.
            raise ValueError(
                "hidden_features must contain at least two entries, got "
                "{}".format(len(hidden_features)))
        self.learn_eps = learn_eps

        self.layers = torch.nn.ModuleList()
        self.batch_norms = torch.nn.ModuleList()
        # Each layer consumes the output width of the previous one. Sizing
        # the MLP input from the current hidden size instead would break
        # forward() whenever consecutive hidden sizes differ.
        prev_dim = in_features
        for hidden_dim in hidden_features:
            mlp = MLP(num_mlp_layers, prev_dim, hidden_dim, hidden_dim)
            self.layers.append(
                GINConv(ApplyNodeFunc(mlp), neighbor_pooling_type, 0,
                        self.learn_eps))
            self.batch_norms.append(nn.BatchNorm1d(hidden_dim))
            prev_dim = hidden_dim

        # NOTE: the last GINConv / BatchNorm1d pair is never used by
        # forward(); linear1 takes its place. It is still created so that
        # the module's state_dict layout stays the same.
        self.linear1 = nn.Linear(hidden_features[-2], hidden_features[-1])
        self.linear2 = nn.Linear(hidden_features[-1], out_features)

    @property
    def model_type(self):
        """Backend identifier of this model; always ``"dgl"``."""
        return "dgl"

    def forward(self, x, adj, dropout=0):
        """Compute the model output for every node.

        Parameters
        ----------
        x : torch.Tensor
            Node feature matrix; it is also stored as node data on the
            graph, so it presumably has one row per node of ``adj``.
        adj : scipy sparse matrix
            Adjacency matrix, converted with ``dgl.from_scipy``.
        dropout : float, optional
            Dropout probability applied before the final linear layer.

        Returns
        -------
        torch.Tensor
            Output of the final linear layer.
        """
        graph = dgl.from_scipy(adj).to(x.device)
        graph.ndata['features'] = x

        # Skip the last (conv, batch-norm) pair: linear1 replaces it.
        for conv, batch_norm in list(zip(self.layers, self.batch_norms))[:-1]:
            x = F.relu(batch_norm(conv(graph, x)))

        x = F.relu(self.linear1(x))
        # NOTE(review): F.dropout defaults to training=True, so a non-zero
        # ``dropout`` is applied even after model.eval(). Callers appear to
        # control this through the argument (default 0) - confirm before
        # switching to training=self.training.
        x = F.dropout(x, dropout)
        x = self.linear2(x)
        return x