add cgct

Evgeneus committed Jun 2, 2021
1 parent 7995a13 commit 16b345f
Showing 5 changed files with 309 additions and 7 deletions.
14 changes: 14 additions & 0 deletions readme.md
@@ -47,6 +47,20 @@ python src/main.py \
--lambda_node 0.3 \
--output_dir 'office31-dcgct/dslr_rest/CDAN'
```
```
python src/main_cgct.py \
--method 'CDAN' \
--encoder 'ResNet50' \
--dataset 'office31' \
--data_root [your office31 folder] \
--source 'dslr' \
--target 'webcam' 'amazon' \
--source_iters 100 \
--adapt_iters 3000 \
--finetune_iters 15000 \
--lambda_node 0.1 \
--output_dir 'office31-cgct/dslr_rest/CDAN'
```

## Office-Home
```
118 changes: 118 additions & 0 deletions src/main_cgct.py
@@ -0,0 +1,118 @@
import os
import random
import argparse
import torch
import numpy as np
from torch.utils.data import DataLoader

import graph_net
import utils
import trainer
import networks
import preprocess


DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

parser = argparse.ArgumentParser(description='Graph Curriculum Domain Adaptation')
# model args
parser.add_argument('--method', type=str, default='CDAN', choices=['CDAN', 'CDAN+E'])
parser.add_argument('--encoder', type=str, default='ResNet50', choices=['ResNet18', 'ResNet50'])
parser.add_argument('--rand_proj', type=int, default=1024, help='random projection dimension')
parser.add_argument('--edge_features', type=int, default=128, help='graph edge features dimension')
parser.add_argument('--save_models', action='store_true', help='whether to save encoder, mlp and gnn models')
# dataset args
parser.add_argument('--dataset', type=str, default='office31',
                    choices=['office31', 'office-home', 'pacs', 'domain-net'], help='dataset used')
parser.add_argument('--source', default='amazon', help='name of source domain')
parser.add_argument('--target', nargs='+', default=['dslr', 'webcam'], help='names of target domains')
parser.add_argument('--data_root', type=str, default='data/office31', help='path to dataset root')
# training args
parser.add_argument('--source_iters', type=int, default=100, help='number of source pre-train iters')
parser.add_argument('--adapt_iters', type=int, default=3000, help='number of iters for a curriculum adaptation')
parser.add_argument('--finetune_iters', type=int, default=1000, help='number of fine-tuning iters')
parser.add_argument('--test_interval', type=int, default=500, help='interval between two consecutive test phases')
parser.add_argument('--output_dir', type=str, default='res', help='output directory')
parser.add_argument('--source_batch', type=int, default=32)
parser.add_argument('--target_batch', type=int, default=32)
# optimization args
parser.add_argument('--lr', type=float, default=0.001, help='learning rate')
parser.add_argument('--wd', type=float, default=0.0005, help='weight decay')
parser.add_argument('--lambda_edge', default=1., type=float, help='edge loss weight')
parser.add_argument('--lambda_node', default=0.3, type=float, help='node classification loss weight')
parser.add_argument('--lambda_adv', default=1.0, type=float, help='adversarial loss weight')
parser.add_argument('--threshold', type=float, default=0.7, help='threshold for pseudo labels')
parser.add_argument('--seed', type=int, default=0, help='random seed for training')
parser.add_argument('--num_workers', type=int, default=4, help='number of workers for dataloaders')


def main(args):
    # fix random seeds for reproducibility
    random.seed(args.seed)
    os.environ['PYTHONHASHSEED'] = str(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    # create train configurations
    args.use_cgct_mask = True  # used in CGCT for the pseudo-label mask in target datasets
    config = utils.build_config(args)
    # prepare data
    dsets, dset_loaders = utils.build_data(config)
    # set base network
    net_config = config['encoder']
    base_network = net_config["name"](**net_config["params"])
    base_network = base_network.to(DEVICE)
    print(base_network)
    # set GNN classifier
    classifier_gnn = graph_net.ClassifierGNN(in_features=base_network.bottleneck.out_features,
                                             edge_features=config['edge_features'],
                                             nclasses=base_network.fc.out_features,
                                             device=DEVICE)
    classifier_gnn = classifier_gnn.to(DEVICE)
    print(classifier_gnn)

    # train on source domain
    base_network, classifier_gnn = trainer.train_source(config, base_network, classifier_gnn, dset_loaders)

    # create random layer and adversarial network
    class_num = config['encoder']['params']['class_num']
    random_layer = networks.RandomLayer([base_network.output_num(), class_num], config['random_dim'], DEVICE)
    adv_net = networks.AdversarialNetwork(config['random_dim'], config['random_dim'], config['ndomains'])
    random_layer = random_layer.to(DEVICE)
    adv_net = adv_net.to(DEVICE)
    print(random_layer)
    print(adv_net)

    # run adaptation episodes
    for curri_iter in range(len(config['data']['target']['name'])):
        print('Starting the adaptation...')
        ######## Step 1: train one adaptation episode on the combined target domains ########
        target_train_datasets = preprocess.ConcatDataset(dsets['target_train'].values())
        dset_loaders['target_train'] = DataLoader(dataset=target_train_datasets,
                                                  batch_size=config['data']['target']['batch_size'],
                                                  shuffle=True, num_workers=config['num_workers'],
                                                  drop_last=True)

        base_network, classifier_gnn = trainer.adapt_target_cgct(config, base_network, classifier_gnn,
                                                                 dset_loaders, random_layer, adv_net)

        ######## Step 2: obtain the target pseudo labels and upgrade the target domains ########
        trainer.upgrade_target_domains(config, dsets, dset_loaders, base_network, classifier_gnn, curri_iter)

    ######## Step 3: fine-tuning stage ########
    config['source_iters'] = config['finetune_iters']
    base_network, classifier_gnn = trainer.train_source(config, base_network, classifier_gnn, dset_loaders)
    print('Finished training and evaluation!')

    # save models
    if args.save_models:
        torch.save(base_network.cpu().state_dict(), os.path.join(config['output_path'], 'base_network.pth.tar'))
        torch.save(classifier_gnn.cpu().state_dict(), os.path.join(config['output_path'], 'classifier_gnn.pth.tar'))


if __name__ == "__main__":
    args = parser.parse_args()
    main(args)
23 changes: 18 additions & 5 deletions src/preprocess.py
@@ -123,25 +123,36 @@ def cummulative_sizes(self):

class ImageList(Dataset):
    def __init__(self, image_root, image_list_root, dataset, domain_label, dataset_name, split='train', transform=None,
                 sample_masks=None, pseudo_labels=None, use_cgct_mask=False):
        self.image_root = image_root
        self.dataset = dataset  # name of the domain
        self.dataset_name = dataset_name  # name of the whole dataset
        self.transform = transform
        self.loader = self._rgb_loader
        self.sample_masks = sample_masks
        self.pseudo_labels = pseudo_labels
        self.use_cgct_mask = use_cgct_mask
        if dataset_name == 'domain-net':
            imgs = self._make_dataset(os.path.join(image_list_root, dataset + '_' + split + '.txt'), domain_label)
        else:
            imgs = self._make_dataset(os.path.join(image_list_root, dataset + '.txt'), domain_label)
        self.imgs = imgs

        # CGCT and D-CGCT use different mask types for the pseudo labels in the target data:
        # D-CGCT keeps only the high-confidence samples from the target data and discards the others,
        # while CGCT keeps all samples and uses the mask to flag the high-confidence ones.
        if self.use_cgct_mask:
            self.sample_masks = sample_masks if sample_masks is not None else torch.zeros(len(self.imgs)).float()
            if pseudo_labels is not None:
                self.labels = self.pseudo_labels
                assert len(self.labels) == len(self.imgs), 'Lengths do not match!'
        else:
            if sample_masks is not None:
                temp_list = self.imgs
                self.imgs = [temp_list[i] for i in self.sample_masks]
            if pseudo_labels is not None:
                self.labels = self.pseudo_labels[self.sample_masks]
                assert len(self.labels) == len(self.imgs), 'Lengths do not match!'
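A minimal sketch of the two masking schemes described in the comment above, with toy tensors that are not from the repo: D-CGCT subsamples the image list by index, while CGCT keeps every sample and carries a per-sample confidence flag (surfaced as `output['mask']` in `__getitem__` below).

```
import torch

# toy pseudo labels for 5 target samples (invented for illustration)
pseudo_labels = torch.tensor([2, 0, 1, 4, 3])

# D-CGCT style (use_cgct_mask=False): sample_masks holds the *indices*
# of the confident samples, and the dataset is subsampled to just those
sample_masks_idx = torch.tensor([1, 3])
kept_labels = pseudo_labels[sample_masks_idx]            # tensor([0, 4])

# CGCT style (use_cgct_mask=True): sample_masks is a 0/1 flag per sample;
# all samples are kept and the flag travels with each item as output['mask']
sample_masks_flag = torch.tensor([0., 1., 0., 1., 0.])
labels = pseudo_labels                                   # all 5 samples kept
confident_only = labels[sample_masks_flag.bool()]        # tensor([0, 4])
```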

def _rgb_loader(self, path):
with open(path, 'rb') as f:
@@ -172,6 +183,8 @@ def __getitem__(self, index):
        output['target'] = torch.squeeze(torch.LongTensor([np.int64(target).item()]))
        output['domain'] = domain
        output['idx'] = index
        if self.use_cgct_mask:
            output['mask'] = torch.squeeze(torch.LongTensor([np.int64(self.sample_masks[index]).item()]))

        return output

153 changes: 153 additions & 0 deletions src/trainer.py
@@ -56,6 +56,9 @@ def eval_domain(config, test_loader, base_network, classifier_gnn):
            logits_gnn_all.append(logits_gnn.cpu())
            confidences_gnn_all.append(nn.Softmax(dim=1)(logits_gnn_all[-1]).max(1)[0])
            labels_all.append(data['target'])
            # TODO: debugging-only early exit; if enabled it caps evaluation at two batches
            # if i == 1:
            #     break
    # concatenate data
    logits_mlp = torch.cat(logits_mlp_all, dim=0)
    logits_gnn = torch.cat(logits_gnn_all, dim=0)
@@ -83,6 +86,7 @@ def eval_domain(config, test_loader, base_network, classifier_gnn):
        'confidences_gnn': confidences_gnn,
        'pred_cls': predict_gnn.numpy(),
        'sample_masks': sample_masks_idx,
        'sample_masks_cgct': sample_masks_bool.float(),
        'pseudo_label_acc': pseudo_label_acc,
        'correct_pseudo_labels': correct_pseudo_labels,
        'total_pseudo_labels': total_pseudo_labels,
@@ -276,6 +280,106 @@ def adapt_target(config, base_network, classifier_gnn, dset_loaders, max_inherit
    return base_network, classifier_gnn


def adapt_target_cgct(config, base_network, classifier_gnn, dset_loaders, random_layer, adv_net):
    # define loss functions
    criterion_gedge = nn.BCELoss(reduction='mean')
    ce_criterion = nn.CrossEntropyLoss()

    # configure optimizer
    optimizer_config = config['optimizer']
    parameter_list = base_network.get_parameters() + adv_net.get_parameters() \
        + [{'params': classifier_gnn.parameters(), 'lr_mult': 10, 'decay_mult': 2}]
    optimizer = optimizer_config['type'](parameter_list, **(optimizer_config['optim_params']))
    # configure learning rates
    param_lr = []
    for param_group in optimizer.param_groups:
        param_lr.append(param_group['lr'])
    schedule_param = optimizer_config['lr_param']

    # start train loop
    len_train_source = len(dset_loaders['source'])
    len_train_target = len(dset_loaders['target_train'])
    # set nets in train mode
    base_network.train()
    classifier_gnn.train()
    adv_net.train()
    random_layer.train()
    for i in range(config['adapt_iters']):
        optimizer = utils.inv_lr_scheduler(optimizer, i, **schedule_param)
        optimizer.zero_grad()
        # get input data
        if i % len_train_source == 0:
            iter_source = iter(dset_loaders['source'])
        if i % len_train_target == 0:
            iter_target = iter(dset_loaders['target_train'])

        batch_source = next(iter_source)
        batch_target = next(iter_target)
        inputs_source, inputs_target = batch_source['img'].to(DEVICE), batch_target['img'].to(DEVICE)
        labels_source, labels_target = batch_source['target'].to(DEVICE), batch_target['target'].to(DEVICE)
        mask_target = batch_target['mask'].bool().to(DEVICE)
        domain_source, domain_target = batch_source['domain'].to(DEVICE), batch_target['domain'].to(DEVICE)
        domain_input = torch.cat([domain_source, domain_target], dim=0)

        # make forward pass for encoder and mlp head
        features_source, logits_mlp_source = base_network(inputs_source)
        features_target, logits_mlp_target = base_network(inputs_target)
        features = torch.cat((features_source, features_target), dim=0)
        logits_mlp = torch.cat((logits_mlp_source, logits_mlp_target), dim=0)
        softmax_mlp = nn.Softmax(dim=1)(logits_mlp)
        # cross-entropy loss for the MLP head, on source labels plus masked target pseudo labels
        mlp_loss = ce_criterion(torch.cat((logits_mlp_source, logits_mlp_target[mask_target]), dim=0),
                                torch.cat((labels_source, labels_target[mask_target]), dim=0))

        # *** GNN at work ***
        # make forward pass for gnn head
        logits_gnn, edge_sim = classifier_gnn(features)
        # compute pseudo labels for the affinity matrix with the mlp classifier
        out_target_class = torch.softmax(logits_mlp_target, dim=1)
        target_score, target_pseudo_labels = out_target_class.max(1, keepdim=True)
        idx_pseudo = target_score > config['threshold']
        idx_pseudo = mask_target.unsqueeze(1) | idx_pseudo
        target_pseudo_labels[~idx_pseudo] = classifier_gnn.mask_val
        # combine source labels and target pseudo labels for edge_net
        node_labels = torch.cat((labels_source, target_pseudo_labels.squeeze(dim=1)), dim=0).unsqueeze(dim=0)
        # compute source-target mask and ground truth for edge_net
        edge_gt, edge_mask = classifier_gnn.label2edge(node_labels)
        # compute edge loss
        edge_loss = criterion_gedge(edge_sim.masked_select(edge_mask), edge_gt.masked_select(edge_mask))
        # cross-entropy loss for the GNN head
        gnn_loss = ce_criterion(
            classifier_gnn(torch.cat((features_source, features_target[mask_target]), dim=0))[0],
            torch.cat((labels_source, labels_target[mask_target]), dim=0))

        # *** Adversarial net at work ***
        if config['method'] == 'CDAN+E':
            entropy = transfer_loss.Entropy(softmax_mlp)
            trans_loss = transfer_loss.CDAN(config['ndomains'], [features, softmax_mlp], adv_net,
                                            entropy, networks.calc_coeff(i), random_layer, domain_input)
        elif config['method'] == 'CDAN':
            trans_loss = transfer_loss.CDAN(config['ndomains'], [features, softmax_mlp],
                                            adv_net, None, None, random_layer, domain_input)
        else:
            raise ValueError('Method cannot be recognized.')

        # total loss and backpropagation
        loss = config['lambda_adv'] * trans_loss + mlp_loss + \
            config['lambda_node'] * gnn_loss + config['lambda_edge'] * edge_loss
        loss.backward()
        optimizer.step()
        # print out train losses
        if i % 20 == 0 or i == config['adapt_iters'] - 1:
            log_str = 'Iters:(%4d/%d)\tMLP loss: %.4f\t GNN Loss: %.4f\t Edge Loss: %.4f\t Transfer loss:%.4f' % (
                i, config["adapt_iters"], mlp_loss.item(), config['lambda_node'] * gnn_loss.item(),
                config['lambda_edge'] * edge_loss.item(), config['lambda_adv'] * trans_loss.item()
            )
            utils.write_logs(config, log_str)
        # evaluate the network every test_interval iterations
        if i % config['test_interval'] == config['test_interval'] - 1:
            evaluate(i, config, base_network, classifier_gnn, dset_loaders['target_test'])

    return base_network, classifier_gnn
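To make the pseudo-label selection inside the loop above concrete, here is a toy run of the thresholding logic. The tensors, the threshold of 0.7, and the mask value of -1 are invented for illustration (the real values come from `config['threshold']` and `classifier_gnn.mask_val`):

```
import torch

threshold, mask_val = 0.7, -1  # assumed values for illustration
probs = torch.tensor([[0.90, 0.10],   # confident prediction
                      [0.60, 0.40],   # below threshold, but flagged earlier
                      [0.55, 0.45]])  # below threshold, not flagged
mask_target = torch.tensor([False, True, False])

score, pseudo = probs.max(1, keepdim=True)
idx_pseudo = score > threshold                        # [[True], [False], [False]]
idx_pseudo = mask_target.unsqueeze(1) | idx_pseudo    # flagged samples always kept
pseudo[~idx_pseudo] = mask_val                        # excluded nodes get the mask value
print(pseudo.squeeze(1))                              # tensor([ 0,  0, -1])
```

The masked nodes still sit in the graph, but `label2edge` can exclude them from the edge ground truth, so only trusted labels shape the edge loss.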


def upgrade_source_domain(config, max_inherit_domain, dsets, dset_loaders, base_network, classifier_gnn):
    target_dataset = ImageList(image_root=config['data_root'], image_list_root=config['data']['image_list_root'],
                               dataset=max_inherit_domain, transform=config['prep']['test'], domain_label=0,
@@ -309,3 +413,52 @@ def upgrade_source_domain(config, max_inherit_domain, dsets, dset_loaders, base_
                                        shuffle=True, num_workers=config['num_workers'],
                                        drop_last=True, pin_memory=True)


def upgrade_target_domains(config, dsets, dset_loaders, base_network, classifier_gnn, curri_iter):
    target_dsets_new = {}
    for target_domain in dsets['target_train']:
        target_dataset = ImageList(image_root=config['data_root'], image_list_root=config['data']['image_list_root'],
                                   dataset=target_domain, transform=config['prep']['test'], domain_label=1,
                                   dataset_name=config['dataset'], split='train')
        target_loader = DataLoader(target_dataset, batch_size=config['data']['test']['batch_size'],
                                   num_workers=config['num_workers'], drop_last=False)
        # set networks to eval mode
        base_network.eval()
        classifier_gnn.eval()
        test_res = eval_domain(config, target_loader, base_network, classifier_gnn)

        # print out logs for the domain
        log_str = 'Adding pseudo labels of dataset: %s\tPseudo-label acc: %.4f (%d/%d)\t Total samples: %d' \
                  % (target_domain, test_res['pseudo_label_acc'] * 100., test_res['correct_pseudo_labels'],
                     test_res['total_pseudo_labels'], len(target_loader.dataset))
        config["out_file"].write(str(log_str) + '\n\n')
        config["out_file"].flush()
        print(log_str + '\n')

        # update pseudo labels
        target_dataset_new = ImageList(image_root=config['data_root'],
                                       image_list_root=config['data']['image_list_root'],
                                       dataset=target_domain, transform=config['prep']['target'],
                                       domain_label=1, dataset_name=config['dataset'], split='train',
                                       sample_masks=test_res['sample_masks_cgct'],
                                       pseudo_labels=test_res['pred_cls'], use_cgct_mask=True)
        target_dsets_new[target_domain] = target_dataset_new

        if curri_iter == config['ndomains'] - 1:
            # sub-sample the dataset, keeping only the confidently pseudo-labeled samples
            target_dataset_new = ImageList(image_root=config['data_root'],
                                           image_list_root=config['data']['image_list_root'],
                                           dataset=target_domain, transform=config['prep']['source'],
                                           domain_label=0, dataset_name=config['dataset'], split='train',
                                           sample_masks=test_res['sample_masks'],
                                           pseudo_labels=test_res['pred_cls'], use_cgct_mask=False)

            # append to the existing source dataset
            dsets['source'] = ConcatDataset((dsets['source'], target_dataset_new))
    dsets['target_train'] = target_dsets_new

    if curri_iter == config['ndomains'] - 1:
        # create a new source dataloader
        dset_loaders['source'] = DataLoader(dsets['source'], batch_size=config['data']['source']['batch_size'] * 2,
                                            shuffle=True, num_workers=config['num_workers'],
                                            drop_last=True, pin_memory=True)
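On the final curriculum episode, the confident target samples are relabeled as source (`domain_label=0`), concatenated onto the source set, and served by a loader with a doubled batch size, which `train_source` then consumes as the fine-tuning stage. A toy sketch of that merge, using `torch.utils.data` stand-ins rather than the repo's `ImageList` and custom `ConcatDataset`:

```
import torch
from torch.utils.data import ConcatDataset, DataLoader, TensorDataset

# toy stand-ins: 8 labeled source samples + 4 confident pseudo-labeled target samples
source = TensorDataset(torch.randn(8, 3), torch.zeros(8).long())
confident_target = TensorDataset(torch.randn(4, 3), torch.ones(4).long())

merged = ConcatDataset((source, confident_target))
loader = DataLoader(merged, batch_size=4 * 2,  # batch size doubled, as above
                    shuffle=True, drop_last=True)
print(len(merged))  # 12 samples feed the fine-tuning stage
```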
