Commit

Merge pull request #104 from clovaai/distributed
Distributed
joonson authored May 2, 2021
2 parents ec285fb + 15d82d1 commit a0466aa
Showing 4 changed files with 126 additions and 76 deletions.
16 changes: 9 additions & 7 deletions DatasetLoader.py
@@ -149,7 +149,7 @@ def __getitem__(self, indices):
if self.augment:
augtype = random.randint(0,4)
if augtype == 1:
audio = self.augment_wav.reverberate(audio)
audio = self.augment_wav.reverberate(audio)
elif augtype == 2:
audio = self.augment_wav.additive_noise('music',audio)
elif augtype == 3:
@@ -192,7 +192,7 @@ def __init__(self, data_source, nPerSpeaker, max_seg_per_spk, batch_size, distri
self.batch_size = batch_size;
self.epoch = 0;
self.seed = seed;
self.distributed = distributed
self.distributed = distributed;

def __iter__(self):

@@ -235,7 +235,7 @@ def __iter__(self):

## Prevent two pairs of the same speaker in the same batch
for ii in mixid:
startbatch = len(mixlabel) - len(mixlabel) % self.batch_size
startbatch = round_down(len(mixlabel), self.batch_size)
if flattened_label[ii] not in mixlabel[startbatch:]:
mixlabel.append(flattened_label[ii])
mixmap.append(ii)
@@ -244,17 +244,19 @@ def __iter__(self):

## Divide data to each GPU
if self.distributed:
total_size = len(mixed_list) - len(mixed_list) % (self.batch_size * dist.get_world_size())
total_size = round_down(len(mixed_list), self.batch_size * dist.get_world_size())
start_index = int ( ( dist.get_rank() ) / dist.get_world_size() * total_size )
end_index = int ( ( dist.get_rank() + 1 ) / dist.get_world_size() * total_size )
self.num_samples = end_index - start_index
return iter(mixed_list[start_index:end_index])
else:
total_size = len(mixed_list) - len(mixed_list) % self.batch_size
total_size = round_down(len(mixed_list), self.batch_size)
self.num_samples = total_size
return iter(mixed_list[:total_size])


def __len__(self):
return len(self.data_source)
def __len__(self) -> int:
return self.num_samples

def set_epoch(self, epoch: int) -> None:
self.epoch = epoch
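
The sampler changes above replace the inline `len(x) - len(x) % d` arithmetic with a `round_down` helper and, under `--distributed`, hand each rank an equally sized contiguous slice of the shuffled batch list, with `__len__` now reporting the per-rank sample count. The helper itself is not shown in this diff; the sketch below assumes it simply floors a count to a multiple of the divisor, which is what the replaced expressions compute.

```python
# Hedged sketch of the assumed round_down helper and the per-rank slicing
# performed in __iter__ above; the repository's actual helper may differ.
def round_down(num, divisor):
    return num - (num % divisor)

def slice_for_rank(mixed_list, batch_size, rank, world_size):
    # Keep only enough items to fill whole batches on every GPU, then give
    # each rank a contiguous slice of equal length.
    total_size = round_down(len(mixed_list), batch_size * world_size)
    start_index = int(rank / world_size * total_size)
    end_index = int((rank + 1) / world_size * total_size)
    return mixed_list[start_index:end_index]

# Example: 10 items, batch_size=2, 2 GPUs -> 8 usable items, 4 per rank.
print(slice_for_rank(list(range(10)), 2, 0, 2))  # [0, 1, 2, 3]
print(slice_for_rank(list(range(10)), 2, 1, 2))  # [4, 5, 6, 7]
```
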
6 changes: 2 additions & 4 deletions README.md
@@ -88,11 +88,9 @@ You can add new models and loss functions to `models` and `loss` directories res

- Use `--distributed` flag to enable distributed training.

- GPU indices should be set using the command `export CUDA_VISIBLE_DEVICES=0,1,2,3`.
- GPU indices should be set before training using the command `export CUDA_VISIBLE_DEVICES=0,1,2,3`.

- Evaluation is not performed between epochs during training.

- If you are running more than one distributed training session, you need to change the port.
- If you are running more than one distributed training session, you need to change the `--port` argument.

### Data

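The README now points users to the `--port` argument when running several distributed sessions on the same machine. The diff does not show how the port is consumed; the sketch below shows the usual pattern, where the port forms the TCP rendezvous address passed to `init_process_group`. The function name and init method here are assumptions, not the repository's exact wiring.

```python
# Hedged sketch: a typical way a --port argument feeds the process-group
# rendezvous; trainSpeakerNet.py's actual initialisation may differ.
import torch.distributed as dist

def init_distributed(rank, world_size, port="8888"):
    dist.init_process_group(
        backend="nccl",
        init_method="tcp://127.0.0.1:" + port,  # each concurrent session needs its own port
        rank=rank,
        world_size=world_size,
    )
```
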
85 changes: 55 additions & 30 deletions SpeakerNet.py
@@ -6,9 +6,9 @@
import torch.nn.functional as F
import numpy, math, pdb, sys, random
import time, os, itertools, shutil, importlib

from tuneThreshold import tuneThresholdfromScore
from DatasetLoader import test_dataset_loader

from torch.cuda.amp import autocast, GradScaler

class WrappedModel(nn.Module):
@@ -115,13 +115,11 @@ def train_network(self, loader, verbose):
counter += 1;
index += stepsize;



telapsed = time.time() - tstart
tstart = time.time()

if verbose:
sys.stdout.write("\rProcessing ({:d}) ".format(index));
sys.stdout.write("\rProcessing {:d} of {:d}:".format(index, loader.__len__()*loader.batch_size));
sys.stdout.write("Loss {:f} TEER/TAcc {:2.3f}% - {:.2f} Hz ".format(loss/counter, top1/counter, stepsize/telapsed));
sys.stdout.flush();

@@ -136,7 +134,12 @@ def train_network(self, loader, verbose):
## Evaluate from list
## ===== ===== ===== ===== ===== ===== ===== =====

def evaluateFromList(self, test_list, test_path, nDataLoaderThread, print_interval=100, num_eval=10, **kwargs):
def evaluateFromList(self, test_list, test_path, nDataLoaderThread, distributed, print_interval=100, num_eval=10, **kwargs):

if distributed:
rank = torch.distributed.get_rank()
else:
rank = 0

self.__model__.eval();

@@ -150,63 +153,85 @@ def evaluateFromList(self, test_list, test_path, nDataLoaderThread, print_interv
lines = f.readlines()

## Get a list of unique file names
files = sum([x.strip().split()[-2:] for x in lines],[])
files = list(itertools.chain(*[x.strip().split()[-2:] for x in lines]))
setfiles = list(set(files))
setfiles.sort()

## Define test data loader
test_dataset = test_dataset_loader(setfiles, test_path, num_eval=num_eval, **kwargs)

if distributed:
sampler = torch.utils.data.distributed.DistributedSampler(test_dataset, shuffle=False)
else:
sampler = None

test_loader = torch.utils.data.DataLoader(
test_dataset,
batch_size=1,
shuffle=False,
num_workers=nDataLoaderThread,
drop_last=False,
sampler=sampler
)

## Extract features for every image
for idx, data in enumerate(test_loader):
inp1 = data[0][0].cuda()
ref_feat = self.__model__(inp1).detach().cpu()
with torch.no_grad():
ref_feat = self.__model__(inp1).detach().cpu()
feats[data[1][0]] = ref_feat
telapsed = time.time() - tstart

if idx % print_interval == 0:
sys.stdout.write("\rReading {:d} of {:d}: {:.2f} Hz, embedding size {:d}".format(idx,len(setfiles),idx/telapsed,ref_feat.size()[1]));
if idx % print_interval == 0 and rank == 0:
sys.stdout.write("\rReading {:d} of {:d}: {:.2f} Hz, embedding size {:d}".format(idx,test_loader.__len__(),idx/telapsed,ref_feat.size()[1]));

print('')
all_scores = [];
all_labels = [];
all_trials = [];
tstart = time.time()

## Read files and compute all scores
for idx, line in enumerate(lines):
if distributed:
## Gather features from all GPUs
feats_all = [None for _ in range(0,torch.distributed.get_world_size())]
torch.distributed.all_gather_object(feats_all, feats)

data = line.split();
if rank == 0:

## Append random label if missing
if len(data) == 2: data = [random.randint(0,1)] + data
tstart = time.time()
print('')

ref_feat = feats[data[1]].cuda()
com_feat = feats[data[2]].cuda()
## Combine gathered features
if distributed:
feats = feats_all[0]
for feats_batch in feats_all[1:]:
feats.update(feats_batch)

if self.__model__.module.__L__.test_normalize:
ref_feat = F.normalize(ref_feat, p=2, dim=1)
com_feat = F.normalize(com_feat, p=2, dim=1)
## Read files and compute all scores
for idx, line in enumerate(lines):

dist = F.pairwise_distance(ref_feat.unsqueeze(-1), com_feat.unsqueeze(-1).transpose(0,2)).detach().cpu().numpy();
data = line.split();

score = -1 * numpy.mean(dist);
## Append random label if missing
if len(data) == 2: data = [random.randint(0,1)] + data

all_scores.append(score);
all_labels.append(int(data[0]));
all_trials.append(data[1]+" "+data[2])
ref_feat = feats[data[1]].cuda()
com_feat = feats[data[2]].cuda()

if idx % print_interval == 0:
telapsed = time.time() - tstart
sys.stdout.write("\rComputing {:d} of {:d}: {:.2f} Hz".format(idx,len(lines),idx/telapsed));
sys.stdout.flush();
if self.__model__.module.__L__.test_normalize:
ref_feat = F.normalize(ref_feat, p=2, dim=1)
com_feat = F.normalize(com_feat, p=2, dim=1)

dist = F.pairwise_distance(ref_feat.unsqueeze(-1), com_feat.unsqueeze(-1).transpose(0,2)).detach().cpu().numpy();

score = -1 * numpy.mean(dist);

all_scores.append(score);
all_labels.append(int(data[0]));
all_trials.append(data[1]+" "+data[2])

if idx % print_interval == 0:
telapsed = time.time() - tstart
sys.stdout.write("\rComputing {:d} of {:d}: {:.2f} Hz".format(idx,len(lines),idx/telapsed));
sys.stdout.flush();

return (all_scores, all_labels, all_trials);

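In the distributed branch of `evaluateFromList` above, every rank extracts embeddings for its shard of the test files, the per-rank `{filename: embedding}` dictionaries are collected with `torch.distributed.all_gather_object`, and only rank 0 merges them and scores the trial list. A standalone sketch of that gather-and-merge step, mirroring the diff:

```python
# Sketch of the gather-and-merge pattern used in evaluateFromList:
# each rank contributes its local {filename: embedding} dict; rank 0 merges.
import torch.distributed as dist

def gather_features(local_feats):
    """local_feats: dict mapping file name -> embedding tensor on this rank."""
    feats_all = [None for _ in range(dist.get_world_size())]
    dist.all_gather_object(feats_all, local_feats)  # collective: call on every rank

    if dist.get_rank() != 0:
        return None  # non-zero ranks skip scoring
    merged = feats_all[0]
    for feats_batch in feats_all[1:]:
        merged.update(feats_batch)
    return merged
```
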
95 changes: 60 additions & 35 deletions trainSpeakerNet.py
@@ -8,6 +8,7 @@
import torch
import glob
import zipfile
import warnings
import datetime
from tuneThreshold import *
from SpeakerNet import *
@@ -52,6 +53,11 @@
parser.add_argument('--nPerSpeaker', type=int, default=1, help='Number of utterances per speaker per batch, only for metric learning based losses');
parser.add_argument('--nClasses', type=int, default=5994, help='Number of speakers in the softmax layer, only for softmax-based losses');

## Evaluation parameters
parser.add_argument('--dcf_p_target', type=float, default=0.05, help='A priori probability of the specified target speaker');
parser.add_argument('--dcf_c_miss', type=float, default=1, help='Cost of a missed detection');
parser.add_argument('--dcf_c_fa', type=float, default=1, help='Cost of a spurious detection');

## Load and save
parser.add_argument('--initial_model', type=str, default="", help='Initial model weights');
parser.add_argument('--save_path', type=str, default="exps/exp1", help='Path for model and logs');
@@ -107,6 +113,7 @@ def find_option_type(key, parser):
except:
pass;

warnings.simplefilter("ignore")

## ===== ===== ===== ===== ===== ===== ===== =====
## Trainer script
@@ -136,9 +143,11 @@ def main_worker(gpu, ngpus_per_node, args):
s = WrappedModel(s).cuda(args.gpu)

it = 1
eers = [100];

## Write args to scorefile
scorefile = open(args.result_save_path+"/scores.txt", "a+");
if args.gpu == 0:
## Write args to scorefile
scorefile = open(args.result_save_path+"/scores.txt", "a+");

## Initialise trainer and data loader
train_dataset = train_dataset_loader(**vars(args))
@@ -162,13 +171,13 @@
modelfiles = glob.glob('%s/model0*.model'%args.model_save_path)
modelfiles.sort()

if len(modelfiles) >= 1:
if(args.initial_model != ""):
trainer.loadParameters(args.initial_model);
print("Model {} loaded!".format(args.initial_model));
elif len(modelfiles) >= 1:
trainer.loadParameters(modelfiles[-1]);
print("Model {} loaded from previous state!".format(modelfiles[-1]));
it = int(os.path.splitext(os.path.basename(modelfiles[-1]))[0][5:]) + 1
elif(args.initial_model != ""):
trainer.loadParameters(args.initial_model);
print("Model {} loaded!".format(args.initial_model));

for ii in range(1,it):
trainer.__scheduler__.step()
@@ -180,21 +189,29 @@

print('Total parameters: ',pytorch_total_params)
print('Test list',args.test_list)

assert args.distributed == False


sc, lab, _ = trainer.evaluateFromList(**vars(args))
result = tuneThresholdfromScore(sc, lab, [1, 0.1]);

p_target = 0.05
c_miss = 1
c_fa = 1
if args.gpu == 0:

result = tuneThresholdfromScore(sc, lab, [1, 0.1]);

fnrs, fprs, thresholds = ComputeErrorRates(sc, lab)
mindcf, threshold = ComputeMinDcf(fnrs, fprs, thresholds, args.dcf_p_target, args.dcf_c_miss, args.dcf_c_fa)

print('\n',time.strftime("%Y-%m-%d %H:%M:%S"), "VEER {:2.4f}".format(result[1]), "MinDCF {:2.5f}".format(mindcf));

if ("nsml" in sys.modules) and args.gpu == 0:
training_report = {};
training_report["summary"] = True;
training_report["epoch"] = it;
training_report["step"] = it;
training_report["val_eer"] = result[1];
training_report["val_dcf"] = mindcf;

fnrs, fprs, thresholds = ComputeErrorRates(sc, lab)
mindcf, threshold = ComputeMinDcf(fnrs, fprs, thresholds, p_target, c_miss, c_fa)
nsml.report(**training_report);

print('\n',time.strftime("%Y-%m-%d %H:%M:%S"), "VEER {:2.4f}".format(result[1]), "VEER {:2.5f}".format(mindcf));
quit();
return

## Save training code and params
if args.gpu == 0:
@@ -218,34 +235,45 @@

loss, traineer = trainer.train_network(train_loader, verbose=(args.gpu == 0));

if it % args.test_interval == 0 and args.gpu == 0:
if args.gpu == 0:
print('\n',time.strftime("%Y-%m-%d %H:%M:%S"), "Epoch {:d}, TEER/TAcc {:2.2f}, TLOSS {:f}, LR {:f}".format(it, traineer, loss, max(clr)));
scorefile.write("Epoch {:d}, TEER/TAcc {:2.2f}, TLOSS {:f}, LR {:f} \n".format(it, traineer, loss, max(clr)));

## Perform evaluation only in single GPU training
if not args.distributed:
sc, lab, _ = trainer.evaluateFromList(**vars(args))
if it % args.test_interval == 0:

sc, lab, _ = trainer.evaluateFromList(**vars(args))

if args.gpu == 0:

result = tuneThresholdfromScore(sc, lab, [1, 0.1]);

print('\n',time.strftime("%Y-%m-%d %H:%M:%S"), "Epoch {:d}, VEER {:2.4f}".format(it, result[1]));
scorefile.write("Epoch {:d}, VEER {:2.4f}\n".format(it, result[1]));
fnrs, fprs, thresholds = ComputeErrorRates(sc, lab)
mindcf, threshold = ComputeMinDcf(fnrs, fprs, thresholds, args.dcf_p_target, args.dcf_c_miss, args.dcf_c_fa)

trainer.saveParameters(args.model_save_path+"/model%09d.model"%it);
eers.append(result[1])

if args.gpu == 0:
print('\n',time.strftime("%Y-%m-%d %H:%M:%S"), "Epoch {:d}, TEER/TAcc {:2.2f}, TLOSS {:f}, LR {:f}".format(it, traineer, loss, max(clr)));
scorefile.write("Epoch {:d}, TEER/TAcc {:2.2f}, TLOSS {:f}, LR {:f} \n".format(it, traineer, loss, max(clr)));
print('\n',time.strftime("%Y-%m-%d %H:%M:%S"), "Epoch {:d}, VEER {:2.4f}, MinDCF {:2.5f}".format(it, result[1], mindcf));
scorefile.write("Epoch {:d}, VEER {:2.4f}, MinDCF {:2.5f}\n".format(it, result[1], mindcf));

trainer.saveParameters(args.model_save_path+"/model%09d.model"%it);

scorefile.flush()
with open(args.model_save_path+"/model%09d.eer"%it, 'w') as eerfile:
eerfile.write('{:2.4f}'.format(result[1]))

scorefile.flush()

if ("nsml" in sys.modules) and args.gpu == 0:
training_report = {};
training_report["summary"] = True;
training_report["epoch"] = it;
training_report["step"] = it;
training_report["train_loss"] = loss;
training_report["min_eer"] = min(eers);

nsml.report(**training_report);

scorefile.close();
if args.gpu == 0:
scorefile.close();


## ===== ===== ===== ===== ===== ===== ===== =====
@@ -255,18 +283,15 @@ def main_worker(gpu, ngpus_per_node, args):

def main():

if ("nsml" in sys.modules):
if ("nsml" in sys.modules) and not args.eval:
args.save_path = os.path.join(args.save_path,SESSION_NAME.replace('/','_'))

args.model_save_path = args.save_path+"/model"
args.result_save_path = args.save_path+"/result"
args.feat_save_path = ""

if not(os.path.exists(args.model_save_path)):
os.makedirs(args.model_save_path)

if not(os.path.exists(args.result_save_path)):
os.makedirs(args.result_save_path)
os.makedirs(args.model_save_path, exist_ok=True)
os.makedirs(args.result_save_path, exist_ok=True)

n_gpus = torch.cuda.device_count()

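The new `--dcf_p_target`, `--dcf_c_miss` and `--dcf_c_fa` options parameterise `ComputeMinDcf`, whose implementation is not part of this diff. Assuming the standard NIST detection cost, the quantity minimised over thresholds is c_miss * p_target * FNR + c_fa * (1 - p_target) * FPR, usually normalised by the better of the two trivial systems; a sketch under that assumption:

```python
# Hedged sketch of a minimum detection cost function in the usual NIST form;
# the repository's ComputeMinDcf may differ in details such as normalisation.
def compute_min_dcf(fnrs, fprs, thresholds, p_target=0.05, c_miss=1.0, c_fa=1.0):
    min_cost, best_threshold = float("inf"), None
    for fnr, fpr, thr in zip(fnrs, fprs, thresholds):
        cost = c_miss * p_target * fnr + c_fa * (1 - p_target) * fpr
        if cost < min_cost:
            min_cost, best_threshold = cost, thr
    # Normalise by the cost of always accepting or always rejecting.
    default_cost = min(c_miss * p_target, c_fa * (1 - p_target))
    return min_cost / default_cost, best_threshold
```
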
