Source code for grb.model.torch.gcn

"""Torch module for GCN."""
import torch
import torch.nn as nn
import torch.nn.functional as F

from grb.utils.normalize import GCNAdjNorm


class GCN(nn.Module):
    r"""

    Description
    -----------
    Graph Convolutional Networks (`GCN <https://arxiv.org/abs/1609.02907>`__)

    Parameters
    ----------
    in_features : int
        Dimension of input features.
    out_features : int
        Dimension of output features.
    hidden_features : int or list of int
        Dimension of hidden features. List if multi-layer.
    n_layers : int
        Number of layers.
    layer_norm : bool, optional
        Whether to use layer normalization. Default: ``False``.
    activation : func of torch.nn.functional, optional
        Activation function. Default: ``torch.nn.functional.relu``.
    residual : bool, optional
        Whether to use residual connection. Default: ``False``.
    feat_norm : str, optional
        Type of features normalization, choose from ["arctan", "tanh", None]. Default: ``None``.
    adj_norm_func : func of utils.normalize, optional
        Function that normalizes adjacency matrix. Default: ``GCNAdjNorm``.
    dropout : float, optional
        Dropout rate during training. Default: ``0.0``.

    """

    def __init__(self,
                 in_features,
                 out_features,
                 hidden_features,
                 n_layers,
                 activation=F.relu,
                 layer_norm=False,
                 residual=False,
                 feat_norm=None,
                 adj_norm_func=GCNAdjNorm,
                 dropout=0.0):
        super(GCN, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.feat_norm = feat_norm
        self.adj_norm_func = adj_norm_func
        if type(hidden_features) is int:
            hidden_features = [hidden_features] * (n_layers - 1)
        elif type(hidden_features) is list or type(hidden_features) is tuple:
            assert len(hidden_features) == (n_layers - 1), \
                "Incompatible sizes between hidden_features and n_layers."
        n_features = [in_features] + hidden_features + [out_features]
        self.layers = nn.ModuleList()
        for i in range(n_layers):
            if layer_norm:
                self.layers.append(nn.LayerNorm(n_features[i]))
            self.layers.append(GCNConv(in_features=n_features[i],
                                       out_features=n_features[i + 1],
                                       activation=activation if i != n_layers - 1 else None,
                                       residual=residual if i != n_layers - 1 else False,
                                       dropout=dropout if i != n_layers - 1 else 0.0))
        self.reset_parameters()

    @property
    def model_type(self):
        """Indicate type of implementation."""
        return "torch"

    @property
    def model_name(self):
        return "gcn"
    def reset_parameters(self):
        """Reset parameters."""
        for layer in self.layers:
            layer.reset_parameters()
    def forward(self, x, adj):
        r"""

        Parameters
        ----------
        x : torch.Tensor
            Tensor of input features.
        adj : torch.SparseTensor
            Sparse tensor of adjacency matrix.

        Returns
        -------
        x : torch.Tensor
            Output of model (logits without activation).

        """
        for layer in self.layers:
            if isinstance(layer, nn.LayerNorm):
                x = layer(x)
            else:
                x = layer(x, adj)

        return x
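
A minimal usage sketch (not part of the module): it builds a tiny graph by hand and calls ``forward`` directly on a sparse adjacency. The sizes and tensors below are illustrative only; in practice the adjacency would typically be normalized first (e.g. with ``GCNAdjNorm``), since ``forward`` does not apply ``adj_norm_func`` itself.

    import torch

    model = GCN(in_features=8, out_features=3, hidden_features=16, n_layers=2)

    x = torch.randn(4, 8)                                   # node features, 4 nodes
    indices = torch.tensor([[0, 1, 2, 3, 0, 1],
                            [0, 1, 2, 3, 1, 0]])            # COO edge indices (self-loops + one edge)
    values = torch.ones(indices.shape[1])
    adj = torch.sparse_coo_tensor(indices, values, (4, 4))  # sparse adjacency (normalize in practice)

    logits = model(x, adj)                                  # shape: (4, 3), per-node logits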
class GCNGC(nn.Module):
    r"""

    Description
    -----------
    Graph Convolutional Networks (`GCN <https://arxiv.org/abs/1609.02907>`__) for graph
    classification: node embeddings are pooled per graph and passed to a final linear classifier.

    Parameters
    ----------
    in_features : int
        Dimension of input features.
    out_features : int
        Dimension of output features.
    hidden_features : int or list of int
        Dimension of hidden features. List if multi-layer.
    n_layers : int
        Number of layers.
    layer_norm : bool, optional
        Whether to use layer normalization. Default: ``False``.
    activation : func of torch.nn.functional, optional
        Activation function. Default: ``torch.nn.functional.relu``.
    residual : bool, optional
        Whether to use residual connection. Default: ``False``.
    feat_norm : str, optional
        Type of features normalization, choose from ["arctan", "tanh", None]. Default: ``None``.
    adj_norm_func : func of utils.normalize, optional
        Function that normalizes adjacency matrix. Default: ``GCNAdjNorm``.
    dropout : float, optional
        Dropout rate during training. Default: ``0.0``.

    """

    def __init__(self,
                 in_features,
                 out_features,
                 hidden_features,
                 n_layers,
                 activation=F.relu,
                 layer_norm=False,
                 residual=False,
                 feat_norm=None,
                 adj_norm_func=GCNAdjNorm,
                 dropout=0.0):
        super(GCNGC, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.feat_norm = feat_norm
        self.adj_norm_func = adj_norm_func
        if type(hidden_features) is int:
            hidden_features = [hidden_features] * (n_layers - 1)
        elif type(hidden_features) is list or type(hidden_features) is tuple:
            assert len(hidden_features) == (n_layers - 1), \
                "Incompatible sizes between hidden_features and n_layers."
        n_features = [in_features] + hidden_features
        self.layers = nn.ModuleList()
        for i in range(n_layers - 1):
            if layer_norm:
                self.layers.append(nn.LayerNorm(n_features[i]))
            self.layers.append(GCNConv(in_features=n_features[i],
                                       out_features=n_features[i + 1],
                                       activation=activation,
                                       residual=residual,
                                       dropout=dropout))
        self.linear = nn.Linear(hidden_features[-1], out_features)
        self.dropout = nn.Dropout(dropout)
        self.reset_parameters()

    @property
    def model_type(self):
        """Indicate type of implementation."""
        return "torch"

    @property
    def model_name(self):
        return "gcn"
    def reset_parameters(self):
        """Reset parameters."""
        for layer in self.layers:
            layer.reset_parameters()
    def forward(self, x, adj, batch_index=None):
        r"""

        Parameters
        ----------
        x : torch.Tensor
            Tensor of input features.
        adj : torch.SparseTensor
            Sparse tensor of adjacency matrix.
        batch_index : torch.LongTensor, optional
            Tensor mapping each node to its graph in the batch. If ``None``,
            all nodes are pooled into a single graph. Default: ``None``.

        Returns
        -------
        out : torch.Tensor
            Output of model (graph-level logits without activation).

        """
        for layer in self.layers:
            if isinstance(layer, nn.LayerNorm):
                x = layer(x)
            else:
                x = layer(x, adj)
        if batch_index is not None:
            # Sum-pool node embeddings per graph via scatter-add.
            batch_size = int(torch.max(batch_index)) + 1
            out = torch.zeros(batch_size, x.shape[1]).to(x.device)
            out = out.scatter_add_(dim=0,
                                   index=batch_index.view(-1, 1).repeat(1, x.shape[1]),
                                   src=x)
        else:
            out = torch.sum(x, dim=0)
        out = self.dropout(self.linear(out))

        return out
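
A hypothetical graph-classification sketch, reusing the imports from the previous example: two tiny graphs are packed into one block-diagonal sparse adjacency, and ``batch_index`` assigns each node to its graph so that the scatter-add pooling in ``forward`` yields one logit row per graph. All names and sizes are illustrative.

    model = GCNGC(in_features=8, out_features=2, hidden_features=16, n_layers=2)

    x = torch.randn(5, 8)                        # 5 nodes: graph 0 has nodes 0-1, graph 1 has nodes 2-4
    indices = torch.tensor([[0, 1, 2, 3, 4],
                            [1, 0, 3, 4, 2]])    # block-diagonal edges, no cross-graph links
    adj = torch.sparse_coo_tensor(indices, torch.ones(5), (5, 5))
    batch_index = torch.tensor([0, 0, 1, 1, 1])  # node -> graph assignment

    graph_logits = model(x, adj, batch_index)    # shape: (2, 2), one row per graph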
class GCNConv(nn.Module):
    r"""

    Description
    -----------
    GCN convolutional layer.

    Parameters
    ----------
    in_features : int
        Dimension of input features.
    out_features : int
        Dimension of output features.
    activation : func of torch.nn.functional, optional
        Activation function. Default: ``None``.
    residual : bool, optional
        Whether to use residual connection. Default: ``False``.
    dropout : float, optional
        Dropout rate during training. Default: ``0.0``.

    """

    def __init__(self,
                 in_features,
                 out_features,
                 activation=None,
                 residual=False,
                 dropout=0.0):
        super(GCNConv, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.linear = nn.Linear(in_features, out_features)
        if residual:
            self.residual = nn.Linear(in_features, out_features)
        else:
            self.residual = None
        self.activation = activation
        if dropout > 0.0:
            self.dropout = nn.Dropout(dropout)
        else:
            self.dropout = None
        self.reset_parameters()
    def reset_parameters(self):
        """Reset parameters."""
        if self.activation == F.leaky_relu:
            gain = nn.init.calculate_gain('leaky_relu')
        else:
            gain = nn.init.calculate_gain('relu')
        nn.init.xavier_normal_(self.linear.weight, gain=gain)
    def forward(self, x, adj):
        r"""

        Parameters
        ----------
        x : torch.Tensor
            Tensor of input features.
        adj : torch.SparseTensor
            Sparse tensor of adjacency matrix.

        Returns
        -------
        x : torch.Tensor
            Output of layer.

        """
        # Keep a reference to the layer input for the residual connection.
        residual = x
        x = self.linear(x)
        x = torch.sparse.mm(adj, x)
        if self.activation is not None:
            x = self.activation(x)
        if self.residual is not None:
            # Residual connection: project the layer input to out_features and add it.
            x = x + self.residual(residual)
        if self.dropout is not None:
            x = self.dropout(x)

        return x
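
A single-layer sketch of ``GCNConv`` on its own, again with illustrative tensors: the layer applies the linear transform, aggregates neighbours with ``torch.sparse.mm``, and applies the activation. In practice the adjacency would be normalized beforehand (e.g. with ``GCNAdjNorm``).

    conv = GCNConv(in_features=8, out_features=4, activation=F.relu)

    x = torch.randn(3, 8)
    indices = torch.tensor([[0, 1, 2],
                            [1, 2, 0]])          # a 3-node directed cycle
    adj = torch.sparse_coo_tensor(indices, torch.ones(3), (3, 3))

    out = conv(x, adj)                           # shape: (3, 4)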