modified the defaults to work better
HaydenFaulkner committed Aug 21, 2018
1 parent cb3d5b7 commit b899f0a
Showing 1 changed file with 59 additions and 49 deletions.
train.py (108 changes: 59 additions & 49 deletions)
@@ -20,10 +20,11 @@ def train(run_id,
model_name,
loss_type,
m, d, k, alpha,
n_iterations=50000,
learning_rate=1e-4,#0.001 this lr (0.001) messed up the learning very badly (doesn't learn)
n_iterations=5000,
net_learning_rate=0.0001,#0.001 this lr (0.001) messed up the learning very badly (doesn't learn)
cluster_learning_rate=0.001,
chunk_size=32,
refresh_clusters_every=500,
refresh_clusters=50,
norm_clusters=False,
calc_acc_every=100,
load_latest=True,
@@ -56,6 +57,10 @@ def train(run_id,
net.cuda()
cudnn.benchmark = True

# make list of cluster refresh if given an interval int
if isinstance(refresh_clusters_every, int):
refresh_clusters_every = list(range(0, n_iterations, refresh_clusters_every))

# Get initial embedding using all samples in training set
initial_reps = compute_all_reps(net, train_dataset, chunk_size)

@@ -64,30 +69,29 @@
the_loss = MagnetLoss(train_y, k, m, d, alpha=alpha, measure='euclidean')

# Initialise the embeddings/representations/clusters
print("Initialising the clusters")
the_loss.update_clusters(initial_reps)

# Setup the optimizer
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, net.parameters()), lr=learning_rate)
elif loss_type == "repmet":
the_loss = RepMetLoss(train_y, k, m, d, alpha=alpha, measure='euclidean')#'cosine')

# Initialise the embeddings/representations/clusters
the_loss.update_clusters(initial_reps)

# Setup the optimizer
# optimizer = torch.optim.Adam(list(net.parameters()) + [the_loss.centroids], lr=learning_rate)
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, net.parameters()), lr=learning_rate)
optimizerb = torch.optim.Adam([the_loss.centroids], lr=0.1)
elif loss_type == "repmet2":
the_loss = RepMetLoss2(train_y, k, m, d, alpha=alpha, measure='euclidean')#'cosine')
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, net.parameters()), lr=net_learning_rate)
optimizerb = None
elif loss_type == "repmet" or loss_type == "repmet2":
if loss_type == "repmet":
the_loss = RepMetLoss(train_y, k, m, d, alpha=alpha, measure='euclidean') # 'cosine')
elif loss_type == "repmet2":
the_loss = RepMetLoss2(train_y, k, m, d, alpha=alpha, measure='euclidean') # 'cosine')

# Initialise the embeddings/representations/clusters
print("Initialising the clusters")
the_loss.update_clusters(initial_reps)

# Setup the optimizer
# optimizer = torch.optim.Adam(list(net.parameters()) + [the_loss.centroids], lr=learning_rate)
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, net.parameters()), lr=learning_rate)
optimizerb = torch.optim.Adam([the_loss.centroids], lr=0.1)#0.001)
if cluster_learning_rate < 0:
optimizer = torch.optim.Adam(list(filter(lambda p: p.requires_grad, net.parameters())) + [the_loss.centroids], lr=net_learning_rate)
optimizerb = None
else:
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, net.parameters()), lr=net_learning_rate)
optimizerb = torch.optim.Adam([the_loss.centroids], lr=cluster_learning_rate)

l = os.listdir(save_path)
if load_latest and len(l) > 1:
@@ -98,7 +102,7 @@ def train(run_id,

net.load_state_dict(state['state_dict'])
optimizer.load_state_dict(state['optimizer'])
if loss_type == "repmet" or loss_type == "repmet2":
if optimizerb:
optimizerb = state['optimizerb']

start_iteration = state['iteration']+1
@@ -178,11 +182,11 @@ def train(run_id,

# Pass the gradient and update
optimizer.zero_grad()
if loss_type == "repmet" or loss_type == "repmet2":
if optimizerb:
optimizerb.zero_grad()
batch_loss.backward()
optimizer.step()
if loss_type == "repmet" or loss_type == "repmet2":
if optimizerb:
optimizerb.step()

if norm_clusters:
@@ -256,7 +260,7 @@ def train(run_id,
savepath="%s/batch-clusters/i%06d%s" % (plots_path, iteration, plots_ext))

train_reps_this_iter = False
if iteration > 0 and not iteration % refresh_clusters_every:
if iteration in refresh_clusters_every:
with open(save_path+'/log.txt', 'a') as f:
f.write('Refreshing clusters')
print('Refreshing clusters')
@@ -378,7 +382,7 @@ def train(run_id,
'train_accs': train_accs,
'test_accs': test_accs,
}
if loss_type == "repmet" or loss_type == "repmet2":
if optimizerb:
state['optimizerb'] = optimizerb.state_dict()
torch.save(state, "%s/i%06d%s" % (save_path, iteration, '.pth'))

@@ -428,7 +432,7 @@ def train(run_id,
'train_accs': train_accs,
'test_accs': test_accs,
}
if loss_type == "repmet" or loss_type == "repmet2":
if optimizerb:
state['optimizerb'] = optimizerb.state_dict()
torch.save(state, "%s/i%06d%s" % (save_path, iteration, '.pth'))

@@ -437,20 +441,21 @@ def parse_args():
parser.add_argument('--run_id', required=True, help='experiment run name', default='000')
parser.add_argument('--set_name', required=True, help='dataset name', default='mnist')
parser.add_argument('--model_name', required=True, help='model name', default='mnist_default')
parser.add_argument('--loss_type', required=True, help='magnet or repmet', default='repmet')
parser.add_argument('--loss_type', required=True, help='magnet, repmet, repmet2', default='repmet2')
parser.add_argument('--m', required=True, help='number of clusters per batch', default=8, type=int)
parser.add_argument('--d', required=True, help='number of samples per cluster per batch', default=8, type=int)
parser.add_argument('--k', required=True, help='number of clusters per class', default=3, type=int)
parser.add_argument('--alpha', required=True, help='cluster margin', default=1.0, type=int)
parser.add_argument('--n_iterations', required=False, help='number of iterations to perform', default=50000, type=int)
parser.add_argument('--learning_rate', required=False, help='the learning rate', default=0.001, type=float)
parser.add_argument('--n_iterations', required=False, help='number of iterations to perform', default=5000, type=int)
parser.add_argument('--net_learning_rate', required=False, help='the learning rate for the net', default=0.0001, type=float)
parser.add_argument('--cluster_learning_rate', required=False, help='the learning rate for the modes (centroids), if -1 will use single optimiser for both net and modes', default=0.001, type=float)
parser.add_argument('--chunk_size', required=False, help='the chunk/batch size for calculating embeddings (lower for less mem)', default=32, type=int)
parser.add_argument('--refresh_clusters_every', required=False, help='refresh the clusters every ? iterations', default=500, type=int)
parser.add_argument('--calc_acc_every', required=False, help='calculate the accuracy every ? iterations', default=100, type=int)
parser.add_argument('--refresh_clusters', required=False, help='refresh the clusters every ? iterations or on these iterations (int or list of ints)', default=[0,1,2])
parser.add_argument('--calc_acc_every', required=False, help='calculate the accuracy every ? iterations', default=10, type=int)
parser.add_argument('--load_latest', required=False, help='load a model if presaved', default=True)
parser.add_argument('--save_every', required=False, help='save the model every ? iterations', default=1000, type=int)
parser.add_argument('--save_every', required=False, help='save the model every ? iterations', default=500, type=int)
parser.add_argument('--save_path', required=False, help='where to save the models', default=configs.general.paths.models)
parser.add_argument('--plot_every', required=False, help='plot graphs every ? iterations', default=500, type=int)
parser.add_argument('--plot_every', required=False, help='plot graphs every ? iterations', default=10, type=int)
parser.add_argument('--plots_path', required=False, help='where to save the plots', default=configs.general.paths.graphing)
parser.add_argument('--plots_ext', required=False, help='.png/.pdf', default='.png')
parser.add_argument('--n_plot_samples', required=False, help='plot ? samples per class', default=10, type=int)
@@ -467,9 +472,10 @@ def parse_args():
# loss_type=args.loss_type,
# m=args.m, d=args.d, k=args.k, alpha=args.alpha,
# n_iterations=args.n_iterations,
# learning_rate=args.learning_rate,
# net_learning_rate=args.net_learning_rate,
# cluster_learning_rate=args.cluster_learning_rate,
# chunk_size=args.chunk_size,
# refresh_clusters_every=args.refresh_clusters_every,
# refresh_clusters=args.refresh_clusters,
# calc_acc_every=args.calc_acc_every,
# load_latest=args.load_latest,
# save_every=args.save_every,
@@ -496,17 +502,20 @@ def parse_args():
# train('006_r50_k1_resnet18_e1024_pt_norm_clust-scaling-norm', 'oxford_flowers', 'resnet18_e1024_pt_norm', 'repmet2',
# m=12, d=4, k=1, alpha=1.0, refresh_clusters_every=50, calc_acc_every=10, plot_every=10, n_iterations=2000,
# norm_clusters=True)
train('005_r50_k1_resnet18_e1024_pt_norm_clust-scaling-norm', 'oxford_flowers', 'resnet18_e1024_pt_norm', 'repmet', m=12, d=4, k=1,
alpha=1.0, refresh_clusters_every=50, calc_acc_every=10, plot_every=10, n_iterations=2000, norm_clusters=True)
train('005_nr_k1_resnet18_e1024_pt_norm_clust-scaling-norm', 'oxford_flowers', 'resnet18_e1024_pt_norm', 'repmet', m=12, d=4, k=1,
alpha=1.0, refresh_clusters_every=5000, calc_acc_every=10, plot_every=10, n_iterations=2000, norm_clusters=True)
train('006_nr_k1_resnet18_e1024_pt_norm_clust-scaling-norm', 'oxford_flowers', 'resnet18_e1024_pt_norm', 'repmet2', m=12, d=4, k=1,
alpha=1.0, refresh_clusters_every=5000, calc_acc_every=10, plot_every=100, n_iterations=2000, norm_clusters=True)

train('005_r50_k1_resnet18_e1024_pt_norm', 'oxford_flowers', 'resnet18_e1024_pt_norm', 'repmet', m=12, d=4, k=1, alpha=1.0, refresh_clusters_every=50, calc_acc_every=10, plot_every=10, n_iterations=2000)
train('006_r50_k1_resnet18_e1024_pt_norm', 'oxford_flowers', 'resnet18_e1024_pt_norm', 'repmet2', m=12, d=4, k=1, alpha=1.0, refresh_clusters_every=50, calc_acc_every=10, plot_every=10, n_iterations=2000)
train('005_nr_k1_resnet18_e1024_pt_norm', 'oxford_flowers', 'resnet18_e1024_pt_norm', 'repmet', m=12, d=4, k=1, alpha=1.0, refresh_clusters_every=5000, calc_acc_every=10, plot_every=10, n_iterations=2000)
train('006_nr_k1_resnet18_e1024_pt_norm', 'oxford_flowers', 'resnet18_e1024_pt_norm', 'repmet2', m=12, d=4, k=1, alpha=1.0, refresh_clusters_every=5000, calc_acc_every=10, plot_every=10, n_iterations=2000)
# train('005_r50_k1_resnet18_e1024_pt_norm_clust-scaling-norm', 'oxford_flowers', 'resnet18_e1024_pt_norm', 'repmet', m=12, d=4, k=1,
# alpha=1.0, refresh_clusters_every=50, calc_acc_every=10, plot_every=10, n_iterations=2000, norm_clusters=True)
# train('005_nr_k1_resnet18_e1024_pt_norm_clust-scaling-norm', 'oxford_flowers', 'resnet18_e1024_pt_norm', 'repmet', m=12, d=4, k=1,
# alpha=1.0, refresh_clusters_every=5000, calc_acc_every=10, plot_every=10, n_iterations=2000, norm_clusters=True)
# train('006_nr_k1_resnet18_e1024_pt_norm_clust-scaling-norm', 'oxford_flowers', 'resnet18_e1024_pt_norm', 'repmet2', m=12, d=4, k=1,
# alpha=1.0, refresh_clusters_every=5000, calc_acc_every=10, plot_every=100, n_iterations=2000, norm_clusters=True)
#
# train('005_r50_k1_resnet18_e1024_pt_norm', 'oxford_flowers', 'resnet18_e1024_pt_norm', 'repmet', m=12, d=4, k=1, alpha=1.0, refresh_clusters_every=50, calc_acc_every=10, plot_every=10, n_iterations=2000)
# train('006_r50_k1_resnet18_e1024_pt_norm', 'oxford_flowers', 'resnet18_e1024_pt_norm', 'repmet2', m=12, d=4, k=1, alpha=1.0, refresh_clusters_every=50, calc_acc_every=10, plot_every=10, n_iterations=2000)
# train('005_r1_k1_resnet18_e1024_pt_norm', 'oxford_flowers', 'resnet18_e1024_pt_norm', 'repmet', m=12, d=4, k=1, alpha=1.0, refresh_clusters_every=1, calc_acc_every=10, plot_every=10, n_iterations=2000)
# train('006_r1_k1_resnet18_e1024_pt_norm', 'oxford_flowers', 'resnet18_e1024_pt_norm', 'repmet2', m=12, d=4, k=1, alpha=1.0, refresh_clusters_every=1, calc_acc_every=10, plot_every=10, n_iterations=2000)
train('006_r0t2_k3_resnet18_e1024_pt_norm_lr.001_clust-scaling-norm', 'oxford_flowers', 'resnet18_e1024_pt_norm', 'repmet2', m=12, d=4, k=3, alpha=1.0, refresh_clusters_every=[0,1,2], calc_acc_every=10, plot_every=10, n_iterations=2000, norm_clusters=True)
# train('005_nr_k1_resnet18_e1024_pt_norm', 'oxford_flowers', 'resnet18_e1024_pt_norm', 'repmet', m=12, d=4, k=1, alpha=1.0, refresh_clusters_every=5000, calc_acc_every=10, plot_every=10, n_iterations=2000)
# train('006_nr_k1_resnet18_e1024_pt_norm', 'oxford_flowers', 'resnet18_e1024_pt_norm', 'repmet2', m=12, d=4, k=1, alpha=1.0, refresh_clusters_every=5000, calc_acc_every=10, plot_every=10, n_iterations=2000)



@@ -586,8 +595,9 @@ def parse_args():
in the positive directions, at least when emb = 2 (just so I can visualise it better); not sure if it does in higher
dimensions. It might be because in 2 dims the clusters being wrong is too much, so it pushes away first...
2) Refreshing the clusters would fix this (clusters moving away too much) to some extent, but it is a balancing act
3) The learning rate on the centroids for repmet also plays a role: too small (1e-4, when they shared one optimiser with
the net) and they don't move; too big and the sample costs blow out into large numbers, which causes 0's for the loss
due to the exponential
4)
3) The learning rate on the centroids for repmet also plays a role: 0.1 or 0.01 is too big and pushes the centroids
too far; we find 0.001 to work well
4) Normalising the centroids helps for faster convergence, and also permits more variation in the learning rate
(a rough sketch of this two-optimiser + normalisation setup follows these notes)
5) Cluster refresh rate is inconclusive at the moment; it seems necessary in the first epoch or two but not so much
after... but how much better is running kmeans every iteration vs never?
"""

