Source code for grb.defense.gcnsvd

import numpy as np
import scipy.sparse as sp
import torch.nn as nn
import torch.nn.functional as F

import grb.utils as utils
from grb.model.torch.gcn import GCNConv
from grb.utils.normalize import GCNAdjNorm


[docs]class GCNSVD(nn.Module): def __init__(self, in_features, out_features, hidden_features, n_layers, activation=F.relu, layer_norm=False, feat_norm=None, adj_norm_func=None, residual=False, dropout=0.0, k=50): super(GCNSVD, self).__init__() self.in_features = in_features self.out_features = out_features self.feat_norm = feat_norm self.adj_norm_func = adj_norm_func if type(hidden_features) is int: hidden_features = [hidden_features] * (n_layers - 1) elif type(hidden_features) is list or type(hidden_features) is tuple: assert len(hidden_features) == (n_layers - 1), "Incompatible sizes between hidden_features and n_layers." n_features = [in_features] + hidden_features + [out_features] self.layers = nn.ModuleList() for i in range(n_layers): if layer_norm: self.layers.append(nn.LayerNorm(n_features[i])) self.layers.append(GCNConv(in_features=n_features[i], out_features=n_features[i + 1], activation=activation if i != n_layers - 1 else None, residual=residual if i != n_layers - 1 else False, dropout=dropout if i != n_layers - 1 else 0.0)) self.k = k @property def model_type(self): return "torch"
[docs] def forward(self, x, adj): adj = self.truncatedSVD(adj, self.k) adj = utils.adj_preprocess(adj=adj, adj_norm_func=self.adj_norm_func, device=x.device) for layer in self.layers: if isinstance(layer, nn.LayerNorm): x = layer(x) else: x = layer(x, adj) return x
[docs] def truncatedSVD(self, adj, k=50): edge_index = adj._indices() row, col = edge_index[0].cpu().data.numpy()[:], edge_index[1].cpu().data.numpy()[:] adj = sp.csr_matrix((np.ones(len(row)), (row, col))) if sp.issparse(adj): adj = adj.asfptype() U, S, V = sp.linalg.svds(adj, k=k) diag_S = np.diag(S) else: U, S, V = np.linalg.svd(adj) U = U[:, :k] S = S[:k] V = V[:k, :] diag_S = np.diag(S) new_adj = U @ diag_S @ V new_adj = sp.csr_matrix(new_adj) return new_adj