"""Torch module for GCN."""
import torch
import torch.nn as nn
import torch.nn.functional as F
from grb.utils.normalize import GCNAdjNorm
class GCN(nn.Module):
    r"""

    Description
    -----------
    Graph Convolutional Networks (`GCN <https://arxiv.org/abs/1609.02907>`__)

    Parameters
    ----------
    in_features : int
        Dimension of input features.
    out_features : int
        Dimension of output features.
    hidden_features : int or list of int or tuple of int
        Dimension of hidden features. An int is used for every hidden layer;
        a list/tuple must contain exactly ``n_layers - 1`` entries.
    n_layers : int
        Number of ``GCNConv`` layers.
    activation : func of torch.nn.functional, optional
        Activation function applied after every layer except the last.
        Default: ``torch.nn.functional.relu``.
    layer_norm : bool, optional
        Whether to apply ``nn.LayerNorm`` before each ``GCNConv`` layer.
        Default: ``False``.
    residual : bool, optional
        Whether to use residual connections in every layer except the last.
        Default: ``False``.
    feat_norm : str, optional
        Type of features normalization, choose from ["arctan", "tanh", None].
        Only stored as ``self.feat_norm``; ``forward`` does not use it.
        Default: ``None``.
    adj_norm_func : func of utils.normalize, optional
        Function that normalizes adjacency matrix. Only stored as
        ``self.adj_norm_func``; ``forward`` does not use it.
        Default: ``GCNAdjNorm``.
    dropout : float, optional
        Dropout rate during training, applied after every layer except the
        last. Default: ``0.0``.

    Raises
    ------
    TypeError
        If ``hidden_features`` is not an int, list or tuple.
    AssertionError
        If a list/tuple ``hidden_features`` does not have ``n_layers - 1``
        entries.

    """

    def __init__(self,
                 in_features,
                 out_features,
                 hidden_features,
                 n_layers,
                 activation=F.relu,
                 layer_norm=False,
                 residual=False,
                 feat_norm=None,
                 adj_norm_func=GCNAdjNorm,
                 dropout=0.0):
        super(GCN, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.feat_norm = feat_norm
        self.adj_norm_func = adj_norm_func

        if isinstance(hidden_features, int):
            hidden_features = [hidden_features] * (n_layers - 1)
        elif isinstance(hidden_features, (list, tuple)):
            assert len(hidden_features) == (n_layers - 1), \
                "Incompatible sizes between hidden_features and n_layers."
            # Copy to a list: ``[in_features] + tuple`` would raise TypeError,
            # and copying also avoids aliasing the caller's list.
            hidden_features = list(hidden_features)
        else:
            raise TypeError(
                "hidden_features must be an int, list or tuple, got "
                f"{type(hidden_features).__name__}.")

        # n_features[i] -> n_features[i + 1] is the width of layer i.
        n_features = [in_features] + hidden_features + [out_features]
        self.layers = nn.ModuleList()
        for i in range(n_layers):
            is_last = i == n_layers - 1
            if layer_norm:
                self.layers.append(nn.LayerNorm(n_features[i]))
            # The output layer emits raw logits: no activation, residual
            # connection or dropout.
            self.layers.append(GCNConv(in_features=n_features[i],
                                       out_features=n_features[i + 1],
                                       activation=None if is_last else activation,
                                       residual=False if is_last else residual,
                                       dropout=0.0 if is_last else dropout))
        self.reset_parameters()

    @property
    def model_type(self):
        """Indicate type of implementation."""
        return "torch"

    @property
    def model_name(self):
        """Short identifier of the model."""
        return "gcn"

    def reset_parameters(self):
        """Reset parameters of every layer (``GCNConv`` and ``LayerNorm``)."""
        for layer in self.layers:
            layer.reset_parameters()

    def forward(self, x, adj):
        r"""

        Parameters
        ----------
        x : torch.Tensor
            Tensor of input features.
        adj : torch.SparseTensor
            Sparse tensor of adjacency matrix.

        Returns
        -------
        x : torch.Tensor
            Output of model (logits without activation).

        """
        for layer in self.layers:
            # LayerNorm only sees node features; GCNConv also needs the graph.
            if isinstance(layer, nn.LayerNorm):
                x = layer(x)
            else:
                x = layer(x, adj)

        return x
class GCNGC(nn.Module):
    r"""

    Description
    -----------
    Graph Convolutional Networks (`GCN <https://arxiv.org/abs/1609.02907>`__)
    for graph-level prediction: node embeddings from the ``GCNConv`` layers
    are sum-pooled per graph and passed through a linear readout layer.

    Parameters
    ----------
    in_features : int
        Dimension of input features.
    out_features : int
        Dimension of output features.
    hidden_features : int or list of int or tuple of int
        Dimension of hidden features. An int is used for every hidden layer;
        a list/tuple must contain exactly ``n_layers - 1`` entries.
    n_layers : int
        Number of layers: ``n_layers - 1`` ``GCNConv`` layers plus the linear
        readout. With ``n_layers == 1`` the raw input features are pooled.
    activation : func of torch.nn.functional, optional
        Activation function applied after every ``GCNConv`` layer.
        Default: ``torch.nn.functional.relu``.
    layer_norm : bool, optional
        Whether to apply ``nn.LayerNorm`` before each ``GCNConv`` layer.
        Default: ``False``.
    residual : bool, optional
        Whether to use residual connections in the ``GCNConv`` layers.
        Default: ``False``.
    feat_norm : str, optional
        Type of features normalization, choose from ["arctan", "tanh", None].
        Only stored as ``self.feat_norm``; ``forward`` does not use it.
        Default: ``None``.
    adj_norm_func : func of utils.normalize, optional
        Function that normalizes adjacency matrix. Only stored as
        ``self.adj_norm_func``; ``forward`` does not use it.
        Default: ``GCNAdjNorm``.
    dropout : float, optional
        Dropout rate during training, used in every ``GCNConv`` layer and on
        the readout output. Default: ``0.0``.

    Raises
    ------
    TypeError
        If ``hidden_features`` is not an int, list or tuple.
    AssertionError
        If a list/tuple ``hidden_features`` does not have ``n_layers - 1``
        entries.

    """

    def __init__(self,
                 in_features,
                 out_features,
                 hidden_features,
                 n_layers,
                 activation=F.relu,
                 layer_norm=False,
                 residual=False,
                 feat_norm=None,
                 adj_norm_func=GCNAdjNorm,
                 dropout=0.0):
        super(GCNGC, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.feat_norm = feat_norm
        self.adj_norm_func = adj_norm_func

        if isinstance(hidden_features, int):
            hidden_features = [hidden_features] * (n_layers - 1)
        elif isinstance(hidden_features, (list, tuple)):
            assert len(hidden_features) == (n_layers - 1), \
                "Incompatible sizes between hidden_features and n_layers."
            # Copy to a list: ``[in_features] + tuple`` would raise TypeError.
            hidden_features = list(hidden_features)
        else:
            raise TypeError(
                "hidden_features must be an int, list or tuple, got "
                f"{type(hidden_features).__name__}.")

        n_features = [in_features] + hidden_features
        self.layers = nn.ModuleList()
        for i in range(n_layers - 1):
            if layer_norm:
                self.layers.append(nn.LayerNorm(n_features[i]))
            self.layers.append(GCNConv(in_features=n_features[i],
                                       out_features=n_features[i + 1],
                                       activation=activation,
                                       residual=residual,
                                       dropout=dropout))
        # Readout input width is that of the last conv layer, or in_features
        # when there are no conv layers (hidden_features[-1] would IndexError).
        self.linear = nn.Linear(n_features[-1], out_features)
        self.dropout = nn.Dropout(dropout)
        self.reset_parameters()

    @property
    def model_type(self):
        """Indicate type of implementation."""
        return "torch"

    @property
    def model_name(self):
        """Short identifier of the model."""
        return "gcn"

    def reset_parameters(self):
        """Reset parameters of every conv/norm layer and of the readout layer."""
        for layer in self.layers:
            layer.reset_parameters()
        self.linear.reset_parameters()

    def forward(self, x, adj, batch_index=None):
        r"""

        Parameters
        ----------
        x : torch.Tensor
            Tensor of input features, indexed by node along dim 0.
        adj : torch.SparseTensor
            Sparse tensor of adjacency matrix.
        batch_index : torch.Tensor, optional
            Integer graph id of every node (one entry per row of ``x``).
            Ids are expected in ``[0, batch_size)``, where ``batch_size`` is
            ``max(batch_index) + 1``. If ``None``, all nodes are pooled into
            a single graph. Default: ``None``.

        Returns
        -------
        x : torch.Tensor
            Output of model (logits without activation): shape
            ``(batch_size, out_features)`` when ``batch_index`` is given,
            otherwise ``(out_features,)`` (assuming ``x`` is 2-D).

        """
        for layer in self.layers:
            if isinstance(layer, nn.LayerNorm):
                x = layer(x)
            else:
                x = layer(x, adj)

        if batch_index is not None:
            index = batch_index.reshape(-1)
            batch_size = int(torch.max(index)) + 1
            # Allocate with x's dtype/device so non-float32 inputs also work.
            out = x.new_zeros(batch_size, x.shape[1])
            # Sum-pool: row index[n] of ``out`` accumulates node n's embedding.
            out = out.index_add_(0, index, x)
        else:
            out = torch.sum(x, dim=0)
        out = self.dropout(self.linear(out))

        return out
class GCNConv(nn.Module):
    r"""

    Description
    -----------
    GCN convolutional layer. Computes ``adj @ linear(x)``, then optionally
    applies the activation, adds a residual projection of the input ``x``,
    and applies dropout, in that order.

    Parameters
    ----------
    in_features : int
        Dimension of input features.
    out_features : int
        Dimension of output features.
    activation : func of torch.nn.functional, optional
        Activation function. Default: ``None``.
    residual : bool, optional
        Whether to use residual connection. The residual is a learned linear
        projection of the layer input to ``out_features``. Default: ``False``.
    dropout : float, optional
        Dropout rate during training. Default: ``0.0``.

    """

    def __init__(self,
                 in_features,
                 out_features,
                 activation=None,
                 residual=False,
                 dropout=0.0):
        super(GCNConv, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.linear = nn.Linear(in_features, out_features)

        if residual:
            self.residual = nn.Linear(in_features, out_features)
        else:
            self.residual = None
        self.activation = activation

        if dropout > 0.0:
            self.dropout = nn.Dropout(dropout)
        else:
            self.dropout = None

        self.reset_parameters()

    def reset_parameters(self):
        """Reset parameters.

        The main weight uses Xavier-normal initialization. The gain is the
        ``leaky_relu`` gain when ``activation`` is ``F.leaky_relu`` and the
        ReLU gain otherwise, including when ``activation`` is ``None``. The
        bias and the optional residual projection get ``nn.Linear``'s
        default initialization.
        """
        if self.activation == F.leaky_relu:
            gain = nn.init.calculate_gain('leaky_relu')
        else:
            gain = nn.init.calculate_gain('relu')
        # Re-draw the bias as well; the weight is then overwritten below.
        self.linear.reset_parameters()
        nn.init.xavier_normal_(self.linear.weight, gain=gain)
        if self.residual is not None:
            self.residual.reset_parameters()

    def forward(self, x, adj):
        r"""

        Parameters
        ----------
        x : torch.Tensor
            Tensor of input features.
        adj : torch.SparseTensor
            Sparse tensor of adjacency matrix.

        Returns
        -------
        x : torch.Tensor
            Output of layer.

        """
        h = self.linear(x)
        h = torch.sparse.mm(adj, h)
        if self.activation is not None:
            h = self.activation(h)
        if self.residual is not None:
            # The residual projection maps the layer *input* (in_features) to
            # out_features. Applying it to ``h`` would fail when in != out.
            h = h + self.residual(x)
        if self.dropout is not None:
            h = self.dropout(h)

        return h