Skip to content

Commit

Permalink
update model and optimize valid step.
Browse files Browse the repository at this point in the history
  • Loading branch information
WangGodder committed Dec 8, 2019
1 parent 83f9352 commit 8aedde8
Show file tree
Hide file tree
Showing 21 changed files with 442 additions and 315 deletions.
2 changes: 1 addition & 1 deletion script/default_config.yml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
training:
method: GCH
method: GANCMH
dataName: Mirflickr25K
batchSize: 64
bit: 64
Expand Down
3 changes: 2 additions & 1 deletion torchcmh/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from __future__ import absolute_import
from __future__ import print_function

__version__ = '0.2.3'
__version__ = '0.2.6'
__author__ = 'Xinzhi Wang'
__description__ = 'Deep Cross Modal Hashing in PyTorch'

Expand All @@ -16,5 +16,6 @@
training,
utils,
loss,
evaluate,
run
)
32 changes: 9 additions & 23 deletions torchcmh/dataset/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,34 +6,20 @@ you need .mat file to read data. for each dataset, there are three part mat file

You can download them from the following URLs:

#### MS-COCO
- imageList.mat....

- tag.mat....

- label.mat....


#### Mirflickr25K
- imageList.mat....

- tag.mat....

- label.mat....
- download: https://pan.baidu.com/s/1XlomcQR2ldu5dzENbukDLw
- password: 2s36

#### Nus wide
- download: https://pan.baidu.com/s/1v6S8pRf18O0C0tdfJVeSVw
- password: 0qi0

- imageList.mat....

- tag.mat....

- label.mat....
#### MS-COCO
- download:
- password:

#### IAPR TC-12
- imageList.mat....

- tag.mat....

- label.mat....
- download:
- password:

After downloading these .mat files, please put them in the data folder using the correct file names.
4 changes: 4 additions & 0 deletions torchcmh/evaluate/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# -*- coding: utf-8 -*-
# @Time : 2019/12/2
# @Author : Godder
# @Github : https://github.com/WangGodder
27 changes: 27 additions & 0 deletions torchcmh/evaluate/classification.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# -*- coding: utf-8 -*-
# @Time : 2019/12/2
# @Author : Godder
# @Github : https://github.com/WangGodder
import torch
__all__ = ['multi_label_acc']


def multi_label_acc(predict: torch.Tensor, label: torch.Tensor):
    r"""
    Multi-label accuracy evaluation.

    For every instance, the ``n`` highest-scoring entries of ``predict`` are
    taken as the predicted labels, where ``n`` is the sum of that instance's
    row in ``label`` (i.e. its number of ground-truth labels when ``label``
    is a {0, 1} indicator matrix -- presumably the intended input):

        acc = \frac{number of correctly predicted labels}{number of labels}

    Instances without any ground-truth label contribute 0 to the mean
    (instead of turning the whole result into NaN through 0 / 0).

    :param predict: the predicted label scores, indexed as [instance, label]
    :param label: the ground truth label with same shape as predict
    :return: the mean accuracy over all instances as a Python float;
        0.0 when there are no instances at all.
    """
    assert predict.shape == label.shape
    num_instances = predict.size(0)
    if num_instances == 0:
        # Nothing to average over; avoid dividing by zero below.
        return 0.0

    label_num = torch.sum(label, dim=-1)
    acc = 0.0
    for i in range(num_instances):
        top_n = int(label_num[i])
        if top_n == 0:
            # No ground-truth label: counts as accuracy 0 rather than NaN.
            continue
        _, predict_ind = torch.topk(predict[i, :], top_n)
        right_num = torch.sum(label[i][predict_ind])
        # Divide as Python floats so integer label tensors cannot trigger
        # integer division.
        acc += float(right_num) / float(label_num[i])
    return acc / num_instances


75 changes: 75 additions & 0 deletions torchcmh/evaluate/hashing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
# -*- coding: utf-8 -*-
# @Time : 2019/12/2
# @Author : Godder
# @Github : https://github.com/WangGodder
import torch
__all__ = ['calc_map_k', 'calc_precisions_topn']


def calc_hammingDist(B1, B2):
    """
    Hamming distance between two sets of hash codes via their inner product.

    Computes ``0.5 * (q - B1 @ B2^T)`` where ``q`` is the second dimension of
    ``B2``. For codes with entries in {-1, +1} this equals the number of
    differing bits.

    :param B1: codes of shape [m, q], or a single code of shape [q]
        (promoted to [1, q])
    :param B2: codes of shape [n, q]
    :return: distance matrix of shape [m, n]
    """
    code_length = B2.shape[1]
    # A single code vector is treated as a batch of one.
    codes = B1 if len(B1.shape) >= 2 else B1.unsqueeze(0)
    inner_product = codes.mm(B2.transpose(0, 1))
    return 0.5 * (code_length - inner_product)


def calc_map_k(qB, rB, query_L, retrieval_L, k=None):
    """
    Mean average precision (mAP) of Hamming-ranking retrieval.

    For each query, the retrieval items are ranked by Hamming distance to
    the query code; a retrieval item is relevant when it shares at least one
    label with the query. Queries without any relevant item add nothing to
    the sum but are still counted in the denominator.

    :param qB: query codes, shape [m, q]; binarised with ``torch.sign`` here
    :param rB: retrieval codes, shape [n, q]; binarised with ``torch.sign`` here
    :param query_L: query labels, {0,1}^{m x l}
    :param retrieval_L: retrieval labels, {0,1}^{n x l}
    :param k: upper bound on the number of relevant items that enter each
        query's average precision; ``None`` means the retrieval set size.
    :return: 0-dim float32 tensor holding the mAP (always a tensor, also when
        there are no queries or no query has a relevant item).
    """
    # NOTE(review): with an explicit k the AP is taken over the first
    # min(k, #relevant) relevant items wherever they are ranked, not over the
    # top-k ranked items. Kept as-is to keep reported numbers comparable --
    # confirm this is the intended definition of mAP@k.
    num_query = query_L.shape[0]
    mean_ap = torch.zeros((), dtype=torch.float32, device=query_L.device)
    if num_query == 0:
        return mean_ap

    qB = torch.sign(qB)
    rB = torch.sign(rB)
    if k is None:
        k = retrieval_L.shape[0]
    retrieval_L_t = retrieval_L.transpose(0, 1)  # loop invariant

    for query_idx in range(num_query):
        q_L = query_L[query_idx]
        if len(q_L.shape) < 2:
            q_L = q_L.unsqueeze(0)  # [1, label length]
        # reshape(-1) rather than squeeze(): stays 1-D even when n == 1.
        gnd = (q_L.mm(retrieval_L_t) > 0).reshape(-1).type(torch.float32)
        tsum = int(torch.sum(gnd))
        if tsum == 0:
            continue
        hamm = calc_hammingDist(qB[query_idx, :], rB)
        _, ind = torch.sort(hamm)
        gnd = gnd[ind.reshape(-1)]
        total = min(k, tsum)
        if total < 1:
            # k <= 0: nothing to average, avoid the mean of an empty tensor.
            continue
        # Build the counter directly on the labels' device instead of a
        # hard-coded .cuda() move.
        count = torch.arange(1, total + 1, dtype=torch.float32, device=gnd.device)
        # 1-based ranks of the first `total` relevant items.
        tindex = torch.nonzero(gnd)[:total].reshape(-1).type(torch.float32) + 1.0
        mean_ap = mean_ap + torch.mean(count / tindex)
    return mean_ap / num_query


def calc_precisions_topn(qB, rB, query_L, retrieval_L, recall_gas=0.02, num_retrieval=10000):
    """
    Precision of Hamming-ranking retrieval at evenly spaced top-N cut-offs.

    The cut-offs are ``recall_gas * j * num_retrieval`` for
    ``j = 1 .. int(1 / recall_gas)``. For every cut-off N the precision is
    the fraction of the N nearest retrieval items (by Hamming distance) that
    share at least one label with the query, averaged over all queries.

    :param qB: query codes, shape [m, q]; binarised as ``sign(x - 0.5)``
    :param rB: retrieval codes, shape [n, q]; binarised as ``sign(x - 0.5)``
    :param query_L: query labels, {0,1}^{m x l}
    :param retrieval_L: retrieval labels, {0,1}^{n x l}
    :param recall_gas: step between consecutive cut-offs, as a fraction of
        ``num_retrieval``
    :param num_retrieval: number of retrieved items at the largest cut-off;
        clamped to the size of the retrieval set so a precision is never
        divided by more items than can actually be retrieved.
    :return: list of ``int(1 / recall_gas)`` mean precisions (Python floats);
        all zeros when there are no queries or no retrieval items.
    """
    qB = torch.sign(qB.float() - 0.5)
    rB = torch.sign(rB.float() - 0.5)
    num_query = query_L.shape[0]
    num_levels = int(1 / recall_gas)
    precisions = [0.0] * num_levels
    if num_query == 0 or retrieval_L.shape[0] == 0:
        return precisions

    num_retrieval = min(num_retrieval, retrieval_L.shape[0])
    # Cut-offs come from an integer level index, so there are exactly
    # len(precisions) of them (a float arange can produce one extra);
    # round() avoids float truncation and max(1, ...) a zero divisor.
    cutoffs = [
        max(1, int(round(num_retrieval * recall_gas * (level + 1))))
        for level in range(num_levels)
    ]
    retrieval_L_t = retrieval_L.transpose(0, 1)  # loop invariant

    for query_idx in range(num_query):
        q_L = query_L[query_idx]
        if len(q_L.shape) < 2:
            q_L = q_L.unsqueeze(0)  # [1, label length]
        # reshape(-1) rather than squeeze(): stays 1-D even when n == 1.
        gnd = (q_L.mm(retrieval_L_t) > 0).reshape(-1).type(torch.float32)
        hamm = calc_hammingDist(qB[query_idx, :], rB)
        _, ind = torch.sort(hamm)
        gnd = gnd[ind.reshape(-1)]
        for level, total in enumerate(cutoffs):
            # gnd is a 0/1 indicator, so its sum is the number of hits;
            # .item() works on CPU and CUDA tensors alike.
            right_num = int(torch.sum(gnd[:total]).item())
            precisions[level] += right_num / total
    return [precision / num_query for precision in precisions]
2 changes: 1 addition & 1 deletion torchcmh/loss/common_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def focal_loss(logit: torch.Tensor, gamma, alpha=1, eps=1e-5):
:param eps: a tiny value prevent log(0)
:return:
"""
return alpha * -torch.pow(1 - logit, gamma) * torch.log(logit + eps)
return -alpha * torch.pow(1 - logit, gamma) * torch.log(logit + eps)


def cosine(hash1: torch.Tensor, hash2: torch.Tensor):
Expand Down
59 changes: 59 additions & 0 deletions torchcmh/models/MLP.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# -*- coding: utf-8 -*-
# @Time : 2019/7/22
# @Author : Godder
# @Github : https://github.com/WangGodder
from torch import nn
from torch.nn import functional as F
from torchcmh.models import BasicModule
import torch


__all__ = ['MLP']


def weights_init(m):
    """
    Initialise 1-D and 2-D convolution layers from a normal distribution.

    Intended to be passed to ``nn.Module.apply``. Weights and biases of
    ``nn.Conv1d`` / ``nn.Conv2d`` modules (including subclasses) are drawn
    from N(mean=0, std=0.01); every other module is left untouched.

    :param m: the module to initialise in place
    """
    if isinstance(m, (nn.Conv1d, nn.Conv2d)):
        nn.init.normal_(m.weight.data, 0.0, 0.01)
        # Layers built with bias=False have no bias tensor to initialise.
        if m.bias is not None:
            nn.init.normal_(m.bias.data, 0.0, 0.01)


class MLP(BasicModule):
    """
    Multi-layer perceptron built from ``nn.Conv1d`` layers.

    The input vector is treated as a single-channel sequence. The first
    convolution uses a kernel that spans ``input_dim`` positions; all later
    convolutions use kernel size 1.
    """

    def __init__(self, input_dim, output_dim, hidden_nodes=(8192,), dropout=None, leakRelu=True):
        """
        :param input_dim: dimension of input
        :param output_dim: bit number of the final binary code
        :param hidden_nodes: widths of the hidden layers, in order
        :param dropout: dropout probability applied after every hidden
            convolution; no dropout layer is added when falsy
        :param leakRelu: use ``nn.LeakyReLU`` after hidden layers when True,
            ``nn.ReLU`` otherwise
        """
        super(MLP, self).__init__()
        self.module_name = "MLP"

        # full-conv layers
        full_conv_layers = []
        in_channel = 1
        for layer_idx, hidden_node in enumerate(hidden_nodes):
            # Only the very first convolution consumes the raw input vector.
            # Keyed on the layer index (not on in_channel == 1) so a hidden
            # width of 1 is not mistaken for the input layer.
            kernel_size = input_dim if layer_idx == 0 else 1
            full_conv_layers.append(nn.Conv1d(in_channel, hidden_node, kernel_size=kernel_size, stride=1))
            in_channel = hidden_node
            if dropout:
                full_conv_layers.append(nn.Dropout(dropout))
            if leakRelu:
                full_conv_layers.append(nn.LeakyReLU(inplace=True))
            else:
                full_conv_layers.append(nn.ReLU(inplace=True))
        full_conv_layers.append(nn.Conv1d(in_channel, output_dim, kernel_size=1, stride=1))
        self.layers = nn.Sequential(*full_conv_layers)
        self.apply(weights_init)

    def forward(self, x: torch.Tensor):
        """
        :param x: input of shape [batch, input_dim], [batch, 1, input_dim], or
            a higher-rank tensor that is flattened per sample
        :return: the network output with singleton dimensions removed, except
            that the batch dimension is kept when the batch size is 1
        """
        batch_size = x.size(0)
        if len(x.shape) == 2:
            x = x.unsqueeze(1)
        elif len(x.shape) > 3:
            # Flatten everything but the batch dimension. A bare squeeze()
            # would also drop the batch dimension for a batch of one.
            x = x.reshape(batch_size, 1, -1)
        x = self.layers(x)
        x = x.squeeze()
        if batch_size == 1:
            # squeeze() removed the batch dimension as well; restore it so a
            # batch of one still yields a batched output.
            x = x.unsqueeze(0)
        return x
21 changes: 8 additions & 13 deletions torchcmh/models/MSText.py → torchcmh/models/MSBlockText.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,10 +87,9 @@ def _init_weight(self):


class MS_Text(BasicModule):
def __init__(self, txt_length, bit, **kwargs):
def __init__(self, txt_length, bit):
super(MS_Text, self).__init__()
self.module_name = 'MS Text'
self.out_feature = kwargs['out_feature'] if 'out_feature' in kwargs.keys() else False
self.module_name = 'MS-BlockTextNet'

# MS blocks
self.block1 = MS_Block(1, 1, 10, txt_length)
Expand Down Expand Up @@ -159,33 +158,29 @@ def _get_feature(self, x):
x = self.release4(x)
return x

def _get_hash(self, x):
def _get_hash(self, x, out_feature=False):
x = self._get_feature(x)

f = self.linear_conv4(x)
h = self.relu(f)
h = self.LRN(h)
h = self.hash_conv4(h)
# h = torch.tanh(h)
h = h.squeeze() # type: torch.Tensor
if self.out_feature and self.training:
if out_feature:
f = f.squeeze()
return h, f

return h

def forward(self, x):
def forward(self, x, out_feature=False):
ms_out = self._get_ms_feature(x)
return self._get_hash(ms_out)
return self._get_hash(ms_out, out_feature)


def get_MS_Text(tag_length, bit, **kwargs):
return MS_Text(tag_length, bit, **kwargs)
def get_MS_Block_Text(tag_length, bit):
return MS_Text(tag_length, bit)


if __name__ == '__main__':
net = MS_Text(1386, 64)
x = torch.randn(4, 1, 1386, 1)
y = net(x)
print(y.shape)
# print(y1)
Loading

0 comments on commit 8aedde8

Please sign in to comment.