Commit fc4467e (1 parent: 0ccbb24): update readme file and some training methods.
Showing 16 changed files with 426 additions and 48 deletions.

@@ -0,0 +1,31 @@ (new file)

# -*- coding: utf-8 -*-
# @Time : 2019/10/28
# @Author : Godder
# @Github : https://github.com/WangGodder
import torch

__all__ = ['focal_loss', 'cosine', 'vector_length']


def focal_loss(logit: torch.Tensor, gamma, alpha=1, eps=1e-5):
    """
    focal loss: -alpha * (1 - logit)^gamma * log(logit)
    :param logit: probability-like value in [0, 1]
    :param gamma: focusing parameter; larger values down-weight well-classified examples
    :param alpha: balancing weight
    :param eps: a tiny value to prevent log(0)
    :return: element-wise focal loss, same shape as logit
    """
    return alpha * -torch.pow(1 - logit, gamma) * torch.log(logit + eps)


def cosine(hash1: torch.Tensor, hash2: torch.Tensor):
    # pairwise cosine similarity: <h1, h2> / (|h1| * |h2|)
    inter = torch.matmul(hash1, hash2.t())
    length1 = vector_length(hash1, keepdim=True)
    length2 = vector_length(hash2, keepdim=True)
    return torch.div(inter, torch.matmul(length1, length2.t()))


def vector_length(vector: torch.Tensor, keepdim=False):
    # L2 norm over the last dimension
    if len(vector.shape) > 2:
        vector = vector.unsqueeze(0)
    return torch.sqrt(torch.sum(torch.pow(vector, 2), dim=-1, keepdim=keepdim))
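
A minimal usage sketch of these helpers (the shapes and hyperparameter values below are illustrative assumptions, and the functions are assumed to be in scope from the module above):

import torch

hash1 = torch.rand(4, 64)                    # hypothetical batch of 4 codes, 64 bits each
hash2 = torch.rand(4, 64)
sim = cosine(hash1, hash2)                   # (4, 4) matrix of pairwise cosine similarities

prob = torch.sigmoid(torch.randn(4))         # stand-in predictions in (0, 1)
loss = focal_loss(prob, gamma=2, alpha=0.25).mean()  # scalar loss; gamma/alpha are arbitrary here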

@@ -11,4 +11,5 @@ MSText_capsnet.py
 MSText_norm.py
 pcb.py
 resnet_gcn.py
 resnet_norm.py
+pretrain_model

@@ -0,0 +1,34 @@ (new file)

# -*- coding: utf-8 -*-
# @Time : 2019/10/29
# @Author : Godder
# @Github : https://github.com/WangGodder
from torchcmh.models import BasicModule
from torchcmh.models.GCN import GraphConvolution
from torch import nn
import torch


class GCN(BasicModule):
    hidden_layer_dim = 1024

    def __init__(self, bit, label_dim):
        super(GCN, self).__init__()
        self.module_name = "GCH.GCN"
        self.gcn1 = GraphConvolution(bit, self.hidden_layer_dim)
        self.gcn2 = GraphConvolution(self.hidden_layer_dim, bit)
        self.hash_layer = nn.Linear(bit, bit)
        self.label_layer = nn.Linear(bit, label_dim)

    def forward(self, x, adjacent_matrix):
        feature = self.gcn1(x, adjacent_matrix)
        feature = torch.relu_(feature)
        feature = self.gcn2(feature, adjacent_matrix)
        feature = torch.sigmoid(feature)
        # hash representation
        hash_represent = self.hash_layer(feature)
        # label generation
        label = self.label_layer(hash_represent)
        # activation
        hash_represent = torch.tanh(hash_represent)
        label = torch.sigmoid(label)
        return hash_represent, label
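
A forward-pass sketch for this module, assuming GraphConvolution takes (features, adjacency) as its two arguments, which is what the calls above suggest; all sizes are illustrative:

import torch

bit, label_dim, n_nodes = 64, 24, 8          # illustrative sizes
model = GCN(bit, label_dim)

x = torch.randn(n_nodes, bit)                # one bit-dimensional feature per node
adj = torch.eye(n_nodes)                     # placeholder adjacency matrix
hash_code, label = model(x, adj)             # shapes: (n_nodes, bit) and (n_nodes, label_dim)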

@@ -0,0 +1,4 @@ (new file)

# -*- coding: utf-8 -*-
# @Time : 2019/10/29
# @Author : Godder
# @Github : https://github.com/WangGodder

@@ -0,0 +1,46 @@ (new file)

# -*- coding: utf-8 -*-
# @Time : 2019/10/29
# @Author : Godder
# @Github : https://github.com/WangGodder
from torchcmh.models import BasicModule
from torch import nn
import torch

__all__ = ['get_label_net', 'get_txt_net']


class EmbeddingNet(BasicModule):
    hidden_dim1 = 512
    hidden_dim2 = 512

    def __init__(self, embedding_dim, bit, label_dim):
        super(EmbeddingNet, self).__init__()
        fc1 = nn.Linear(embedding_dim, self.hidden_dim1)
        fc2 = nn.Linear(self.hidden_dim1, self.hidden_dim2)
        fc3 = nn.Linear(self.hidden_dim2, bit)
        relu = nn.ReLU(inplace=True)
        self.fc = nn.Sequential(fc1, relu, fc2, relu, fc3)
        self.label_layer = nn.Linear(bit, label_dim)

    def forward(self, x: torch.Tensor):
        if len(x.shape) > 2:
            x = x.squeeze()
        hash_represent = self.fc(x)
        if self.training:
            label_generation = self.label_layer(hash_represent)
            hash_represent = torch.tanh(hash_represent)
            label_generation = torch.sigmoid(label_generation)
            return hash_represent, label_generation
        else:
            return hash_represent


def get_label_net(label_dim, bit):
    label_net = EmbeddingNet(label_dim, bit, label_dim)
    label_net.module_name = "GCH.LabelNet"
    return label_net


def get_txt_net(txt_dim, bit, label_dim):
    txt_net = EmbeddingNet(txt_dim, bit, label_dim)
    txt_net.module_name = "GCH.TextNet"
    return txt_net
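
A sketch of how the factory functions might be used (dimensions are illustrative assumptions). Note that in training mode the network returns both the tanh-activated hash and a sigmoid label prediction, while in eval mode it returns only the raw, un-activated hash:

import torch

txt_dim, bit, label_dim = 1386, 64, 24       # illustrative sizes
txt_net = get_txt_net(txt_dim, bit, label_dim)

bow = torch.randn(16, txt_dim)               # a batch of bag-of-words text vectors
txt_net.train()
hash_code, label_pred = txt_net(bow)         # (16, bit), (16, label_dim)
txt_net.eval()
raw_hash = txt_net(bow)                      # (16, bit), no tanh applied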

@@ -0,0 +1,94 @@ (new file)

# -*- coding: utf-8 -*-
# @Time : 2019/10/29
# @Author : Godder
# @Github : https://github.com/WangGodder
import torch.nn as nn
import torch
from torchcmh.models import abs_dir, BasicModule
import os

__all__ = ['get_image_net']

pretrain_model = os.path.join(abs_dir, "pretrain_model", "imagenet-vgg-f.pth")


class VGG_F(BasicModule):
    def __init__(self, bit, label_dim):
        super(VGG_F, self).__init__()
        self.module_name = "GCH.vgg-f"
        self.features = nn.Sequential(
            # 0 conv1
            nn.Conv2d(in_channels=3, out_channels=64, kernel_size=11, stride=4),
            # 1 relu1
            nn.ReLU(inplace=True),
            # 2 norm1
            nn.LocalResponseNorm(size=2, k=2),
            # 3 pool1
            nn.ZeroPad2d((0, 1, 0, 1)),
            nn.MaxPool2d(kernel_size=(3, 3), stride=2),
            # 4 conv2
            nn.Conv2d(in_channels=64, out_channels=256, kernel_size=5, stride=1, padding=2),
            # 5 relu2
            nn.ReLU(inplace=True),
            # 6 norm2
            nn.LocalResponseNorm(size=2, k=2),
            # 7 pool2
            nn.MaxPool2d(kernel_size=(3, 3), stride=2),
            # 8 conv3
            nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1),
            # 9 relu3
            nn.ReLU(inplace=True),
            # 10 conv4
            nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1),
            # 11 relu4
            nn.ReLU(inplace=True),
            # 12 conv5
            nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1),
            # 13 relu5
            nn.ReLU(inplace=True),
            # 14 pool5
            nn.MaxPool2d(kernel_size=(3, 3), stride=(2, 2), padding=(0, 0)),
            # 15 full_conv6
            nn.Conv2d(in_channels=256, out_channels=4096, kernel_size=6),
            # 16 relu6
            nn.ReLU(inplace=True),
            # 17 full_conv7
            nn.Conv2d(in_channels=4096, out_channels=4096, kernel_size=1),
            # 18 relu7
            nn.ReLU(inplace=True),
        )
        # fc8
        self.classifier = nn.Linear(in_features=4096, out_features=bit)
        self.label_layer = nn.Linear(bit, label_dim)
        self._init()

    def _init(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.features(x)
        x = x.squeeze()
        x = self.classifier(x)
        if self.training:
            label = self.label_layer(x)
            label = torch.sigmoid(label)
            x = torch.tanh(x)
            return x, label
        else:
            return x


def get_image_net(bit, label_dim, pretrain=True):
    model = VGG_F(bit, label_dim)
    if pretrain:
        model.init_pretrained_weights(pretrain_model)
    return model
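
A forward-pass sketch with pretrain=False so the imagenet-vgg-f.pth weights file is not required. The 224x224 input size is an assumption chosen so that full_conv6's 6x6 kernel reduces the spatial map to 1x1 (224 -> 54 -> 55 -> 27 -> 13 -> 6 -> 1 through the conv/pool stack above):

import torch

bit, label_dim = 64, 24                      # illustrative sizes
model = get_image_net(bit, label_dim, pretrain=False)

images = torch.randn(8, 3, 224, 224)         # batch of 8 RGB images
model.train()
hash_code, label_pred = model(images)        # (8, bit), (8, label_dim)
model.eval()
raw_hash = model(images)                     # (8, bit), no tanh applied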

@@ -21,4 +21,8 @@ QDCMH.py
 TDH.py
 SSAH.py
 MCGDH.py
-RDCMH.py
+RDCMH.py
+CMHH.py
+CHN.py
+SCAHN.py
+GCH.py