Skip to content

Commit

Permalink
update GCN
Browse files Browse the repository at this point in the history
  • Loading branch information
WangGodder committed Dec 9, 2019
1 parent 8aedde8 commit 1becba5
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 1 deletion.
1 change: 0 additions & 1 deletion torchcmh/models/.gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
capsnet.py
resnet_capsnet.py
CompactBilinearPooling.py
GCN.py
GCNText.py
GCMH
QDCMH
Expand Down
46 changes: 46 additions & 0 deletions torchcmh/models/GCN.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# -*- coding: utf-8 -*-
# @Time : 2019/8/28
# @Author : Godder
# @Github : https://github.com/WangGodder
import math
import torch
from torch.nn.parameter import Parameter
from torch import nn


class GraphConvolution(nn.Module):
"""
Simple GCN layer, similar to https://arxiv.org/abs/1609.02907
"""

def __init__(self, in_features, out_features, bias=True):
super(GraphConvolution, self).__init__()
self.in_features = in_features
self.out_features = out_features
self.weight = Parameter(torch.FloatTensor(in_features, out_features))
if bias:
self.bias = Parameter(torch.FloatTensor(out_features))
else:
self.register_parameter('bias', None)
self.reset_parameters()

def reset_parameters(self):
stdv = 1. / math.sqrt(self.weight.size(1))
self.weight.data.uniform_(-stdv, stdv)
if self.bias is not None:
self.bias.data.uniform_(-stdv, stdv)

def forward(self, input, adj):
support = torch.mm(input, self.weight)
output = torch.spmm(adj, support)
if self.bias is not None:
return output + self.bias
else:
return output

def __repr__(self):
return self.__class__.__name__ + ' (' \
+ str(self.in_features) + ' -> ' \
+ str(self.out_features) + ')'


0 comments on commit 1becba5

Please sign in to comment.