Source code for grb.model.dgl.gat

import dgl
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn.pytorch import GATConv


[docs]class GAT(nn.Module): def __init__(self, in_features, out_features, hidden_features, num_heads, activation=F.leaky_relu, layer_norm=False): super(GAT, self).__init__() self.layers = nn.ModuleList() if layer_norm: self.layers.append(nn.LayerNorm(in_features)) self.layers.append(GATConv(in_features, hidden_features[0], num_heads, activation=activation)) for i in range(len(hidden_features) - 1): if layer_norm: self.layers.append(nn.LayerNorm(hidden_features[i] * num_heads)) self.layers.append( GATConv(hidden_features[i] * num_heads, hidden_features[i + 1], num_heads, activation=activation)) self.layers.append(GATConv(hidden_features[-1] * num_heads, num_heads, out_features)) @property def model_type(self): return "dgl"
[docs] def forward(self, x, adj, dropout=0): graph = dgl.from_scipy(adj).to(x.device) graph.ndata['features'] = x for i, layer in enumerate(self.layers): if isinstance(layer, nn.LayerNorm): x = layer(x) else: x = layer(graph, x).flatten(1) if i != len(self.layers) - 1: x = F.dropout(x, dropout) return x