
Commit fc4467e
update readme file and some training methods.
WangGodder committed Nov 21, 2019
1 parent 0ccbb24 commit fc4467e
Showing 16 changed files with 426 additions and 48 deletions.
3 changes: 3 additions & 0 deletions README.md
@@ -26,3 +26,6 @@ if __name__ == '__main__':
    run(config_path='default_config.yml')
```
### how to create your method
- create a new method file in the folder ./torchcmh/training/
- inherit and implement TrainBase
- change the method name in the config .yml file and run (a minimal sketch follows below)
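A minimal sketch of such a method file (`MyMethod` is a hypothetical name; the exact `TrainBase` constructor arguments are an assumption here, so mirror an existing trainer such as `torchcmh/training/CMHH.py` for the real hooks):

```python
# torchcmh/training/MyMethod.py  (hypothetical file name)
from torchcmh.training.base import TrainBase


class MyMethod(TrainBase):
    def __init__(self, data_name, img_dir, bit, **kwargs):
        # NOTE: the TrainBase signature below is assumed, not confirmed;
        # copy it from an existing trainer in this folder.
        super(MyMethod, self).__init__("MyMethod", data_name, bit)
        # build the image/text models, optimizers and loss parameters here

    def train(self, num_works=4):
        # per epoch: forward both modalities, compute the loss,
        # update the hash codes, then advance the plotter
        raise NotImplementedError
```

With the file in place, setting `method: MyMethod` in the .yml file selects it.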
4 changes: 2 additions & 2 deletions script/default_config.yml
@@ -1,5 +1,5 @@
training:
-  method: DCMH
+  method: GCH
  dataName: Mirflickr25K
  batchSize: 64
  bit: 64
@@ -42,7 +42,7 @@ dataPreprocess:
  onehot: True

dataAugmentation:
-  enable: True
+  enable: False
  img:
    enable: True
    originalRetention: 0.2
2 changes: 1 addition & 1 deletion torchcmh/__init__.py
@@ -5,7 +5,7 @@
from __future__ import absolute_import
from __future__ import print_function

-__version__ = '0.2.1'
+__version__ = '0.2.3'
__author__ = 'Xinzhi Wang'
__description__ = 'Deep Cross Modal Hashing in PyTorch'

31 changes: 31 additions & 0 deletions torchcmh/loss/common_loss.py
@@ -0,0 +1,31 @@
# -*- coding: utf-8 -*-
# @Time : 2019/10/28
# @Author : Godder
# @Github : https://github.com/WangGodder
import torch
__all__ = ['focal_loss', 'cosine', 'vector_length']


def focal_loss(logit: torch.Tensor, gamma, alpha=1, eps=1e-5):
    """
    focal loss: -alpha * (1 - logit)^gamma * log(logit)
    :param logit: probability-like value in (0, 1]
    :param gamma: focusing parameter; larger values down-weight easy examples
    :param alpha: balancing weight
    :param eps: a tiny value to prevent log(0)
    :return: the element-wise focal loss
    """
    return alpha * -torch.pow(1 - logit, gamma) * torch.log(logit + eps)


def cosine(hash1: torch.Tensor, hash2: torch.Tensor):
    # pairwise cosine similarity between the rows of hash1 and hash2
    inter = torch.matmul(hash1, hash2.t())
    length1 = vector_length(hash1, keepdim=True)
    length2 = vector_length(hash2, keepdim=True)
    return torch.div(inter, torch.matmul(length1, length2.t()))


def vector_length(vector: torch.Tensor, keepdim=False):
    # flatten trailing dimensions so each row is treated as a single vector
    if len(vector.shape) > 2:
        vector = vector.view(vector.shape[0], -1)
    return torch.sqrt(torch.sum(torch.pow(vector, 2), dim=-1, keepdim=keepdim))
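A quick usage sketch of the three helpers (the shapes and values are illustrative only):

```python
import torch
from torchcmh.loss.common_loss import focal_loss, cosine, vector_length

hash1 = torch.randn(8, 64).sign()    # 8 codes of 64 bits in {-1, +1}
hash2 = torch.randn(16, 64).sign()

sim = cosine(hash1, hash2)           # (8, 16) pairwise cosine similarities
norms = vector_length(hash1)         # (8,) Euclidean norms, all sqrt(64) = 8
prob = torch.sigmoid(sim)            # squash into (0, 1) before focal_loss
loss = focal_loss(prob, gamma=2).mean()
```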
33 changes: 16 additions & 17 deletions torchcmh/loss/distance.py
@@ -3,19 +3,7 @@
# @Author : Godder
# @Github : https://github.com/WangGodder
import torch
-__all__ = ['focal_loss', 'hamming_dist', 'euclidean_dist']
-
-
-def focal_loss(logit: torch.Tensor, gamma, alpha=1, eps=1e-5):
-    """
-    focal loss: /alpha * (1 - logit)^{gamma} * log^{logit}
-    :param logit: logit value 0 ~ 1
-    :param gamma:
-    :param alpha:
-    :param eps: a tiny value prevent log(0)
-    :return:
-    """
-    return alpha * -torch.pow(1 - logit, gamma) * torch.log(logit + eps)
+__all__ = ['hamming_dist', 'euclidean_dist_matrix', 'euclidean_dist']


def hamming_dist(hash1, hash2):
@@ -30,12 +18,12 @@ def hamming_dist(hash1, hash2):
    return distH


-def euclidean_dist(tensor1: torch.Tensor, tensor2: torch.Tensor):
+def euclidean_dist_matrix(tensor1: torch.Tensor, tensor2: torch.Tensor):
    """
    calculate euclidean distance via inner products
-    :param tensor1:
-    :param tensor2:
-    :return:
+    :param tensor1: a tensor with shape (a, c)
+    :param tensor2: a tensor with shape (b, c)
+    :return: the euclidean distance matrix, where each entry is the distance between a row of tensor1 and a row of tensor2.
    """
    dim1 = tensor1.shape[0]
    dim2 = tensor2.shape[0]
@@ -45,3 +33,14 @@ def euclidean_dist(tensor1: torch.Tensor, tensor2: torch.Tensor):
    dist = torch.sqrt(a2 + b2 - 2 * multi)
    return dist


def euclidean_dist(tensor1: torch.Tensor, tensor2: torch.Tensor):
    """
    calculate the euclidean distance between two lists of vectors, row by row.
    :param tensor1: tensor with shape (a, b)
    :param tensor2: tensor with shape (a, b)
    :return: a tensor with shape (a,) holding the distance between corresponding rows
    """
    sub = tensor1 - tensor2
    dist = torch.sqrt(torch.sum(torch.pow(sub, 2), dim=1))
    return dist
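`euclidean_dist_matrix` relies on the identity ||a - b||^2 = ||a||^2 + ||b||^2 - 2<a, b>, which is why a single matrix multiplication covers all pairs, while `euclidean_dist` only compares corresponding rows. A small sanity check (using `torch.cdist` as the reference; the shapes are illustrative):

```python
import torch
from torchcmh.loss.distance import euclidean_dist_matrix, euclidean_dist

a = torch.randn(4, 64)
b = torch.randn(5, 64)

d_mat = euclidean_dist_matrix(a, b)    # (4, 5) all-pairs distances
assert torch.allclose(d_mat, torch.cdist(a, b), atol=1e-4)

d_row = euclidean_dist(a, a + 1.0)     # (4,) row-wise: all sqrt(64) = 8
```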
1 change: 1 addition & 0 deletions torchcmh/models/.gitignore
@@ -11,4 +11,5 @@ MSText_capsnet.py
MSText_norm.py
pcb.py
resnet_gcn.py
resnet_norm.py
pretrain_model
34 changes: 34 additions & 0 deletions torchcmh/models/GCH/GC_net.py
@@ -0,0 +1,34 @@
# -*- coding: utf-8 -*-
# @Time : 2019/10/29
# @Author : Godder
# @Github : https://github.com/WangGodder
from torchcmh.models import BasicModule
from torchcmh.models.GCN import GraphConvolution
from torch import nn
import torch


class GCN(BasicModule):
    hidden_layer_dim = 1024

    def __init__(self, bit, label_dim):
        super(GCN, self).__init__()
        self.module_name = "GCH.GCN"
        self.gcn1 = GraphConvolution(bit, self.hidden_layer_dim)
        self.gcn2 = GraphConvolution(self.hidden_layer_dim, bit)
        self.hash_layer = nn.Linear(bit, bit)
        self.label_layer = nn.Linear(bit, label_dim)

    def forward(self, x, adjacent_matrix):
        feature = self.gcn1(x, adjacent_matrix)
        feature = torch.relu_(feature)
        feature = self.gcn2(feature, adjacent_matrix)
        feature = torch.sigmoid(feature)
        # hash representation
        hash_represent = self.hash_layer(feature)
        # label generation
        label = self.label_layer(hash_represent)
        # activations
        hash_represent = torch.tanh(hash_represent)
        label = torch.sigmoid(label)
        return hash_represent, label
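A forward-pass sketch, assuming `GraphConvolution(in_dim, out_dim)` takes the node features and a (batch, batch) adjacency matrix, as the calls above suggest; the identity adjacency and the dimensions are placeholders:

```python
import torch
from torchcmh.models.GCH.GC_net import GCN

bit, label_dim, batch = 64, 24, 8
net = GCN(bit, label_dim)

codes = torch.randn(batch, bit)      # one hash-like row per sample
adj = torch.eye(batch)               # placeholder; e.g. built from label similarity
hash_repr, label = net(codes, adj)   # (8, 64) in (-1, 1) and (8, 24) in (0, 1)
```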
4 changes: 4 additions & 0 deletions torchcmh/models/GCH/__init__.py
@@ -0,0 +1,4 @@
# -*- coding: utf-8 -*-
# @Time : 2019/10/29
# @Author : Godder
# @Github : https://github.com/WangGodder
46 changes: 46 additions & 0 deletions torchcmh/models/GCH/enbedding_net.py
@@ -0,0 +1,46 @@
# -*- coding: utf-8 -*-
# @Time : 2019/10/29
# @Author : Godder
# @Github : https://github.com/WangGodder
from torchcmh.models import BasicModule
from torch import nn
import torch
__all__ = ['get_label_net', 'get_txt_net']


class EmbeddingNet(BasicModule):
    hidden_dim1 = 512
    hidden_dim2 = 512

    def __init__(self, embedding_dim, bit, label_dim):
        super(EmbeddingNet, self).__init__()
        fc1 = nn.Linear(embedding_dim, self.hidden_dim1)
        fc2 = nn.Linear(self.hidden_dim1, self.hidden_dim2)
        fc3 = nn.Linear(self.hidden_dim2, bit)
        relu = nn.ReLU(inplace=True)
        self.fc = nn.Sequential(fc1, relu, fc2, relu, fc3)
        self.label_layer = nn.Linear(bit, label_dim)

    def forward(self, x: torch.Tensor):
        if len(x.shape) > 2:
            x = x.squeeze()
        hash_represent = self.fc(x)
        if self.training:
            label_generation = self.label_layer(hash_represent)
            hash_represent = torch.tanh(hash_represent)
            label_generation = torch.sigmoid(label_generation)
            return hash_represent, label_generation
        else:
            return hash_represent


def get_label_net(label_dim, bit):
    label_net = EmbeddingNet(label_dim, bit, label_dim)
    label_net.module_name = "GCH.LabelNet"
    return label_net


def get_txt_net(txt_dim, bit, label_dim):
    txt_net = EmbeddingNet(txt_dim, bit, label_dim)
    txt_net.module_name = "GCH.TextNet"
    return txt_net
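Both factories wrap the same `EmbeddingNet`: in training mode the forward pass returns the tanh-activated hash plus a sigmoid label prediction, while in eval mode it returns only the raw hash representation. A sketch (the 1386-dimensional tag vector and 24 labels are assumptions matching the usual Mirflickr25K setup):

```python
import torch
from torchcmh.models.GCH.enbedding_net import get_txt_net

txt_net = get_txt_net(txt_dim=1386, bit=64, label_dim=24)
tags = torch.randn(8, 1386)

txt_net.train()
hash_repr, label_pred = txt_net(tags)   # (8, 64) and (8, 24)

txt_net.eval()
hash_repr = txt_net(tags)               # (8, 64), no tanh applied here
```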
94 changes: 94 additions & 0 deletions torchcmh/models/GCH/image_net.py
@@ -0,0 +1,94 @@
# -*- coding: utf-8 -*-
# @Time : 2019/10/29
# @Author : Godder
# @Github : https://github.com/WangGodder
import torch.nn as nn
import torch
from torchcmh.models import abs_dir, BasicModule
import os
__all__ = ['get_image_net']

pretrain_model = os.path.join(abs_dir, "pretrain_model", "imagenet-vgg-f.pth")


class VGG_F(BasicModule):
    def __init__(self, bit, label_dim):
        super(VGG_F, self).__init__()
        self.module_name = "GCH.vgg-f"
        self.features = nn.Sequential(
            # 0 conv1
            nn.Conv2d(in_channels=3, out_channels=64, kernel_size=11, stride=4),
            # 1 relu1
            nn.ReLU(inplace=True),
            # 2 norm1
            nn.LocalResponseNorm(size=2, k=2),
            # 3 pool1
            nn.ZeroPad2d((0, 1, 0, 1)),
            nn.MaxPool2d(kernel_size=(3, 3), stride=2),
            # 4 conv2
            nn.Conv2d(in_channels=64, out_channels=256, kernel_size=5, stride=1, padding=2),
            # 5 relu2
            nn.ReLU(inplace=True),
            # 6 norm2
            nn.LocalResponseNorm(size=2, k=2),
            # 7 pool2
            nn.MaxPool2d(kernel_size=(3, 3), stride=2),
            # 8 conv3
            nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1),
            # 9 relu3
            nn.ReLU(inplace=True),
            # 10 conv4
            nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1),
            # 11 relu4
            nn.ReLU(inplace=True),
            # 12 conv5
            nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1),
            # 13 relu5
            nn.ReLU(inplace=True),
            # 14 pool5
            nn.MaxPool2d(kernel_size=(3, 3), stride=(2, 2), padding=(0, 0)),
            # 15 full_conv6
            nn.Conv2d(in_channels=256, out_channels=4096, kernel_size=6),
            # 16 relu6
            nn.ReLU(inplace=True),
            # 17 full_conv7
            nn.Conv2d(in_channels=4096, out_channels=4096, kernel_size=1),
            # 18 relu7
            nn.ReLU(inplace=True),
        )
        # fc8
        self.classifier = nn.Linear(in_features=4096, out_features=bit)
        self.label_layer = nn.Linear(bit, label_dim)
        self._init()

    def _init(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.features(x)
        x = x.squeeze()
        x = self.classifier(x)
        if self.training:
            label = self.label_layer(x)
            label = torch.sigmoid(label)
            x = torch.tanh(x)
            return x, label
        else:
            return x


def get_image_net(bit, label_dim, pretrain=True):
    model = VGG_F(bit, label_dim)
    if pretrain:
        model.init_pretrained_weights(pretrain_model)
    return model
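A usage sketch with `pretrain=False`, so the `imagenet-vgg-f.pth` weights are not required; with 224x224 RGB inputs the layer arithmetic above reduces each image to a 4096-d vector before `classifier` (batch size and label count are placeholders):

```python
import torch
from torchcmh.models.GCH.image_net import get_image_net

net = get_image_net(bit=64, label_dim=24, pretrain=False)
images = torch.randn(8, 3, 224, 224)

net.train()
hash_repr, label_pred = net(images)   # (8, 64) in (-1, 1) and (8, 24) in (0, 1)

net.eval()
hash_repr = net(images)               # (8, 64), unactivated hash values
```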
6 changes: 5 additions & 1 deletion torchcmh/training/.gitignore
@@ -21,4 +21,8 @@ QDCMH.py
TDH.py
SSAH.py
MCGDH.py
-RDCMH.py
+RDCMH.py
+CMHH.py
+CHN.py
+SCAHN.py
+GCH.py
7 changes: 4 additions & 3 deletions torchcmh/training/CMHH.py
@@ -7,11 +7,12 @@
from torch.optim import SGD
from tqdm import tqdm
from torch.utils.data import DataLoader
-from torchcmh.models import alexnet, mlp, vgg_f
+from torchcmh.models import mlp, vgg_f
from torchcmh.training.base import TrainBase
from torchcmh.utils import calc_neighbor
from torchcmh.dataset.utils import single_data
-from torchcmh.loss.distance import focal_loss, euclidean_dist
+from torchcmh.loss.distance import euclidean_dist_matrix
+from torchcmh.loss.common_loss import focal_loss


class CMHH(TrainBase):
@@ -101,7 +102,7 @@ def train(self, num_works=4):
            self.plotter.next_epoch()

    def object_function(self, cur_h, O, label, ind):
-        hamming_dist = euclidean_dist(cur_h, O)
+        hamming_dist = euclidean_dist_matrix(cur_h, O)
        logit = torch.exp(-hamming_dist * self.parameters['beta'])
        sim = calc_neighbor(label, self.train_label)
        focal_pos = sim * focal_loss(logit, gamma=self.parameters['gamma'], alpha=self.parameters['alpha'])
