# Source code for grb.model.dgl.gcn

import dgl
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn.pytorch import GraphConv


class GCN(nn.Module):
    """Graph Convolutional Network built from DGL ``GraphConv`` layers.

    The network is a stack of ``GraphConv`` layers with sizes
    ``in_features -> hidden_features[0] -> ... -> hidden_features[-1] -> out_features``.
    Every ``GraphConv`` except the last applies ``activation``; the last one
    has no activation. When ``layer_norm`` is true, an ``nn.LayerNorm`` is
    inserted in front of every ``GraphConv`` except the last one.

    Args:
        in_features: Dimension of the input node features.
        out_features: Output dimension of the final ``GraphConv``.
        hidden_features: Sequence of hidden layer sizes, one ``GraphConv``
            per entry. A single ``int`` is treated as a one-element list.
        activation: Activation passed to every hidden ``GraphConv``.
        layer_norm: If true, normalise the input of each hidden
            ``GraphConv`` with ``nn.LayerNorm``.

    Raises:
        ValueError: If ``hidden_features`` is empty.
    """

    def __init__(self, in_features, out_features, hidden_features,
                 activation=F.relu, layer_norm=False):
        super().__init__()
        if isinstance(hidden_features, int):
            hidden_features = [hidden_features]
        hidden_features = list(hidden_features)
        if not hidden_features:
            raise ValueError("hidden_features must contain at least one layer size")

        # The module order below reproduces the original construction exactly
        # (LN?, GC, LN?, GC, ..., GC). ModuleList parameter names are
        # index-based ("layers.0", "layers.1", ...), so changing the order
        # would break loading of previously saved state_dicts.
        self.layers = nn.ModuleList()
        dims = [in_features] + hidden_features
        for d_in, d_out in zip(dims[:-1], dims[1:]):
            if layer_norm:
                self.layers.append(nn.LayerNorm(d_in))
            self.layers.append(GraphConv(d_in, d_out, activation=activation))
        # Output layer: no activation and no preceding LayerNorm.
        self.layers.append(GraphConv(hidden_features[-1], out_features))

    @property
    def model_type(self):
        """Backend identifier for this model; always ``"dgl"``."""
        return "dgl"

    def forward(self, x, adj, dropout=0):
        """Run the network on one graph.

        Args:
            x: Node feature tensor. Presumably shaped
                ``(num_nodes, in_features)``; TODO confirm against callers.
            adj: SciPy sparse adjacency matrix, converted with
                ``dgl.from_scipy`` and moved to ``x.device`` on every call.
            dropout: Dropout probability applied to the output of each
                non-final ``GraphConv``. It only takes effect in training
                mode (``self.training``).

        Returns:
            Output of the final ``GraphConv`` (no activation applied).
        """
        # NOTE(review): DGL's GraphConv rejects graphs containing
        # 0-in-degree nodes unless allow_zero_in_degree=True; callers
        # presumably supply adj with self-loops. Confirm.
        graph = dgl.from_scipy(adj).to(x.device)
        last = len(self.layers) - 1
        for i, layer in enumerate(self.layers):
            if isinstance(layer, nn.LayerNorm):
                # No dropout after normalisation; dropout belongs to the
                # GraphConv outputs only, so the rate does not depend on
                # whether layer_norm is enabled.
                x = layer(x)
                continue
            x = layer(graph, x)
            if i != last:
                # training=self.training keeps eval() deterministic even if
                # a nonzero dropout is passed.
                x = F.dropout(x, p=dropout, training=self.training)
        return x