Skip to content

Commit

Permalink
update readme and MLP model.
Browse files Browse the repository at this point in the history
  • Loading branch information
WangGodder committed Dec 18, 2019
1 parent 61bbcfe commit c848913
Show file tree
Hide file tree
Showing 6 changed files with 30 additions and 16 deletions.
8 changes: 7 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# torch cross modal (torchcmh)
# torch cross modal hashing (torchcmh)

torchcmh is a library built on PyTorch for deep learning cross modal hashing.\
if you want use the dataset i make, please download mat file and image file by readme file in dataset package.
Expand All @@ -11,14 +11,20 @@ you need to install these package to run
- create a configuration file as ./script/default_config.yml
```yaml
training:
# the name of python file in training
method: DCMH
# the data set name, you can choose mirflickr25k, nus wide, ms coco, iapr tc-12
dataName: Mirflickr25K
batchSize: 64
# the bit of hash codes
bit: 64
# if true, the program will be run on gpu
cuda: True
# the device id you want to use, if you want to multi gpu, you can use [id1, id2]
device: 0
datasetPath:
Mirflickr25k:
# the path you download the image of data set.
img_dir: I:\dataset\mirflickr25k\mirflickr

```
Expand Down
18 changes: 9 additions & 9 deletions torchcmh/dataset/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@
def triplet_data(dataset_name: str, img_dir: str, **kwargs):
if dataset_name.lower() == 'mirflickr25k':
from torchcmh.dataset.mirflckr25k import get_triplet_datasets
elif dataset_name.lower() == 'nus wide':
elif dataset_name.lower() in ['nus wide', 'nuswide']:
from torchcmh.dataset.nus_wide import get_triplet_datasets
elif dataset_name.lower() == 'coco2014':
elif dataset_name.lower() in ['coco2014', 'coco', 'mscoco', 'ms coco']:
from torchcmh.dataset.coco2014 import get_triplet_datasets
elif dataset_name.lower() == 'iapr tc-12':
elif dataset_name.lower() in ['iapr tc-12', 'iapr', 'tc-12', 'tc12']:
from torchcmh.dataset.tc12 import get_triplet_datasets
else:
raise ValueError("there is no dataset name is %s" % dataset_name)
Expand All @@ -25,11 +25,11 @@ def triplet_data(dataset_name: str, img_dir: str, **kwargs):
def single_data(dataset_name: str, img_dir: str, **kwargs):
if dataset_name.lower() == 'mirflickr25k':
from torchcmh.dataset.mirflckr25k import get_single_datasets
elif dataset_name.lower() == 'nus wide':
elif dataset_name.lower() in ['nus wide', 'nuswide']:
from torchcmh.dataset.nus_wide import get_single_datasets
elif dataset_name.lower() == 'coco2014':
elif dataset_name.lower() in ['coco2014', 'coco', 'mscoco', 'ms coco']:
from torchcmh.dataset.coco2014 import get_single_datasets
elif dataset_name.lower() == 'iapr tc-12':
elif dataset_name.lower() in ['iapr tc-12', 'iapr', 'tc-12', 'tc12']:
from torchcmh.dataset.tc12 import get_single_datasets
else:
raise ValueError("there is no dataset name is %s" % dataset_name)
Expand All @@ -40,11 +40,11 @@ def single_data(dataset_name: str, img_dir: str, **kwargs):
def pairwise_data(dataset_name: str, img_dir: str, **kwargs):
if dataset_name.lower() == 'mirflickr25k':
from torchcmh.dataset.mirflckr25k import get_pairwise_datasets
elif dataset_name.lower() == 'nus wide':
elif dataset_name.lower() in ['nus wide', 'nuswide']:
from torchcmh.dataset.nus_wide import get_pairwise_datasets
elif dataset_name.lower() == 'coco2014':
elif dataset_name.lower() in ['coco2014', 'coco', 'mscoco', 'ms coco']:
from torchcmh.dataset.coco2014 import get_pairwise_datasets
elif dataset_name.lower() == 'iapr tc-12':
elif dataset_name.lower() in ['iapr tc-12', 'iapr', 'tc-12', 'tc12']:
from torchcmh.dataset.tc12 import get_pairwise_datasets
else:
raise ValueError("there is no dataset name is %s" % dataset_name)
Expand Down
5 changes: 4 additions & 1 deletion torchcmh/models/.gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,8 @@ MSText_capsnet.py
MSText_norm.py
pcb.py
resnet_gcn.py
resnet_norm.py
pretrain_model
ASCHN
ACMH
GCH
LCMH
4 changes: 3 additions & 1 deletion torchcmh/models/MLP.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def weights_init(m):


class MLP(BasicModule):
def __init__(self, input_dim, output_dim, hidden_nodes=[8192], dropout=None, leakRelu=False):
def __init__(self, input_dim, output_dim, hidden_nodes=[8192], dropout=None, leakRelu=False, norm=False):
"""
:param input_dim: dimension of input
:param output_dim: bit number of the final binary code
Expand All @@ -35,6 +35,8 @@ def __init__(self, input_dim, output_dim, hidden_nodes=[8192], dropout=None, lea
kernel_size = input_dim if in_channel == 1 else 1
full_conv_layers.append(nn.Conv1d(in_channel, hidden_node, kernel_size=kernel_size, stride=1))
in_channel = hidden_node
if norm:
full_conv_layers.append(nn.BatchNorm1d(hidden_node))
if dropout:
full_conv_layers.append(nn.Dropout(dropout))
if leakRelu:
Expand Down
5 changes: 2 additions & 3 deletions torchcmh/models/resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,6 @@ def __init__(self, num_classes, loss, block, layers,
dropout_p=None,
**kwargs):
self.inplanes = 64
self.out_feature = kwargs['out_feature'] if 'out_feature' in kwargs.keys() else False
super(ResNet, self).__init__()
self.loss = loss
self.feature_dim = 512 * block.expansion
Expand Down Expand Up @@ -207,7 +206,7 @@ def featuremaps(self, x):
x = self.layer4(x)
return x

def forward(self, x):
def forward(self, x, out_feature=False):
f = self.featuremaps(x)
v = self.global_avgpool(f)
v = v.view(v.size(0), -1)
Expand All @@ -216,7 +215,7 @@ def forward(self, x):
v = self.fc(v)

y = self.classifier(v)
if self.out_feature and self.training:
if out_feature:
return y, v
return y

Expand Down
6 changes: 5 additions & 1 deletion torchcmh/utils/plotter.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,12 @@

class VisdomLinePlotter(object):
"""Plots to Visdom"""
def __init__(self, env_name='challenge'):
def __init__(self, env_name='plotter'):
self.viz = Visdom()
self.env = env_name
self.plots = {}
self.epoch = 0

def plot(self, var_name, split_name, y, x=None):
if x is None:
x = self.epoch
Expand All @@ -25,9 +26,12 @@ def plot(self, var_name, split_name, y, x=None):
))
else:
self.viz.line(X=np.array([x]), Y=np.array([y]), env=self.env, win=self.plots[var_name], name=split_name, update='append')

def next_epoch(self):
self.epoch += 1

def reset_epoch(self):
self.epoch = 0

def get_plotter(env_name: str):
return VisdomLinePlotter(env_name)
Expand Down

0 comments on commit c848913

Please sign in to comment.