
Commit a2c918c: init
panxiao.94 committed May 20, 2021 (0 parents)

Showing 16 changed files with 1,485 additions and 0 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -0,0 +1 @@
.idea
4 changes: 4 additions & 0 deletions README.md
@@ -0,0 +1,4 @@
# Contrastive Learning for Many-to-many Multilingual Neural Machine Translation (mCOLT), ACL 2021
Code for training mCOLT, a multilingual NMT training framework built on fairseq.


44 changes: 44 additions & 0 deletions examples/configs/parallel_mono_12e12d_contrastive.yml
@@ -0,0 +1,44 @@
model_dir: model/pretrain/lab/multilingual/l2r/multi_bpe32k/parallel_mono_contrastive_1/transformer_big_t2t_12e12d
data_1: data/multilingual/bin/merged_deduped_ras
data_mono_1: data/multilingual/bin/mono_only/splitaa
data_mono_2: data/multilingual/bin/mono_only/splitab
data_mono_3: data/multilingual/bin/mono_only/splitac
data_mono_4: data/multilingual/bin/mono_only/splitad
data_mono_5: data/multilingual/bin/mono_only/splitae
data_mono_6: data/multilingual/bin/mono_only/mono_de_fr_en
data_mono_7: data/multilingual/bin/mono_only/mono_nl_pl_pt
source_lang: src
target_lang: trg
task: translation_w_mono
parallel_ratio: 0.2
mono_ratio: 0.07
arch: transformer_big_t2t_12e12d
share_all_embeddings: true
encoder_learned_pos: true
decoder_learned_pos: true
max_source_positions: 1024
max_target_positions: 1024
dropout: 0.1
criterion: label_smoothed_cross_entropy_with_contrastive
contrastive_lambda: 1.0
temperature: 0.1
lr: 0.0003
clip_norm: 10.0
optimizer: adam
adam_eps: 1e-06
weight_decay: 0.01
warmup_updates: 10000
label_smoothing: 0.1
lr_scheduler: polynomial_decay
min_lr: -1
max_tokens: 1536
update_freq: 30
max_update: 5000000
no_scale_embedding: true
layernorm_embedding: true
save_interval_updates: 2000
skip_invalid_size_inputs_valid_test: true
log_interval: 500
num_workers: 1
fp16: true
seed: 33122
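The launch script that consumes this YAML is not part of this commit, so the following is only a sketch of how its keys could map onto `fairseq-train` flags. Keys such as `model_dir`, `data_1`, `data_mono_*`, `parallel_ratio`, and `mono_ratio` are presumably read by the custom `translation_w_mono` task or the launcher rather than by fairseq itself; that handling, and the helper name `yml_to_fairseq_cmd`, are assumptions for illustration.

```python
# Hypothetical helper, not part of the commit: turn the config above into a
# fairseq-train command line. The treatment of repo-specific keys
# (model_dir, data_*, *_ratio) is an assumption.
import shlex
import yaml


def yml_to_fairseq_cmd(path):
    with open(path) as f:
        cfg = yaml.safe_load(f)
    save_dir = cfg.pop("model_dir")   # assumed to map to --save-dir
    data = cfg.pop("data_1")          # assumed to be the positional data dir
    task_specific = {k: cfg.pop(k) for k in list(cfg)
                     if k.startswith("data_mono_") or k in ("parallel_ratio", "mono_ratio")}
    # --user-dir assumes the command is run from the repository root.
    args = [data, "--user-dir", "mcolt", "--save-dir", save_dir]
    for key, value in cfg.items():
        flag = "--" + key.replace("_", "-")
        if value is True:             # store_true flags such as --fp16
            args.append(flag)
        else:
            args.extend([flag, str(value)])
    return "fairseq-train " + " ".join(shlex.quote(a) for a in args), task_specific


cmd, extra = yml_to_fairseq_cmd("examples/configs/parallel_mono_12e12d_contrastive.yml")
print(cmd)
print("keys left for the custom task/launcher:", sorted(extra))
```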
4 changes: 4 additions & 0 deletions mcolt/__init__.py
@@ -0,0 +1,4 @@
from .arches import *
from .criterions import *
from .data import *
from .tasks import *
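These wildcard imports make the `@register_*` decorators in the submodules run at import time: the criterion below uses `@register_criterion`, and the task and architecture named in the config are presumably registered the same way in files not rendered here. Passing `--user-dir` to fairseq then makes them selectable by name. A quick sanity check, assuming a 2021-era fairseq that exposes its registries as module-level dictionaries:

```python
# Registration check; the registry attribute names assume a 2021-era fairseq.
import mcolt  # noqa: F401 -- the wildcard imports above run the register_* decorators

from fairseq import criterions, models, tasks

print("translation_w_mono" in tasks.TASK_REGISTRY)                  # custom task
print("transformer_big_t2t_12e12d" in models.ARCH_MODEL_REGISTRY)   # custom arch
print("label_smoothed_cross_entropy_with_contrastive"
      in criterions.CRITERION_REGISTRY)                             # custom criterion
```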
1 change: 1 addition & 0 deletions mcolt/arches/__init__.py
@@ -0,0 +1 @@
from .transformer import *
380 changes: 380 additions & 0 deletions mcolt/arches/transformer.py

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions mcolt/criterions/__init__.py
@@ -0,0 +1 @@
from .label_smoothed_cross_entropy_with_contrastive import *
123 changes: 123 additions & 0 deletions mcolt/criterions/label_smoothed_cross_entropy_with_contrastive.py
@@ -0,0 +1,123 @@
import math

from fairseq.criterions import register_criterion
from fairseq.criterions.label_smoothed_cross_entropy import LabelSmoothedCrossEntropyCriterion
from fairseq import metrics, utils

from collections import deque

import torch
import torch.nn as nn


@register_criterion("label_smoothed_cross_entropy_with_contrastive")
class LabelSmoothedCrossEntropyCriterionWithContrastive(
    LabelSmoothedCrossEntropyCriterion
):
    def __init__(self, task, sentence_avg, label_smoothing, ignore_prefix_size=0, report_accuracy=False,
                 contrastive_lambda=0.0,
                 temperature=1.0):
        super().__init__(task, sentence_avg, label_smoothing, ignore_prefix_size, report_accuracy)
        self.contrastive_lambda = contrastive_lambda
        self.temperature = temperature

    @staticmethod
    def add_args(parser):
        LabelSmoothedCrossEntropyCriterion.add_args(parser)
        parser.add_argument("--contrastive-lambda", type=float,
                            default=0.0,
                            help="The contrastive loss weight")
        parser.add_argument("--temperature", type=float,
                            default=1.0,
                            help="Temperature used to scale the contrastive softmax")

    def swap_sample(self, sample):
        # Build the reversed sample: the target becomes the encoder input and
        # the source (with a prepended begin token) becomes the decoder target.
        target = sample["target"]
        prev_output_tokens = sample["net_input"]["prev_output_tokens"]
        src_tokens = torch.cat((prev_output_tokens[:, :1], sample["net_input"]["src_tokens"]), dim=-1)
        return {
            "net_input": {
                "src_tokens": target.contiguous(),
                "src_lengths": (target != self.padding_idx).int().sum(dim=1),
                "prev_output_tokens": src_tokens[:, :-1].contiguous(),
            },
            "nsentences": sample["nsentences"],
            "ntokens": utils.item((src_tokens[:, 1:] != self.padding_idx).int().sum().data),
            "target": src_tokens[:, 1:].contiguous(),
            "id": sample["id"],
        }

    def forward(self, model, sample, reduce=True):
        # Label-smoothed cross-entropy on the original translation direction.
        net_output = model(**sample["net_input"])
        loss, nll_loss = self.compute_loss(model, net_output, sample, reduce=reduce)
        # Encode the source and the swapped (target-as-source) batch, then pull
        # the two sentence representations together with the contrastive loss.
        encoder_out = model.encoder.forward(
            sample["net_input"]["src_tokens"], sample["net_input"]["src_lengths"]).encoder_out
        reverse_sample = self.swap_sample(sample)
        reversed_encoder_out = model.encoder.forward(
            reverse_sample["net_input"]["src_tokens"], reverse_sample["net_input"]["src_lengths"]).encoder_out
        contrastive_loss = self.get_contrastive_loss(
            encoder_out,
            reversed_encoder_out,
            sample,
            reverse_sample,
        )
        sample_size = (
            sample["target"].size(0) if self.sentence_avg else sample["ntokens"]
        )
        nsentences = sample["target"].size(0)
        ntokens = sample["ntokens"]
        # Rescale the sentence-level contrastive loss to the token scale of the
        # cross-entropy loss before weighting it with contrastive_lambda.
        all_loss = loss + contrastive_loss * self.contrastive_lambda * ntokens / nsentences
        logging_output = {
            "loss": loss.data,
            "nll_loss": nll_loss.data,
            "ntokens": ntokens,
            "nsentences": nsentences,
            "sample_size": sample_size,
        }
        if isinstance(contrastive_loss, int):
            logging_output["contrastive_loss"] = 0
        else:
            logging_output["contrastive_loss"] = utils.item(contrastive_loss.data)

        return all_loss, sample_size, logging_output

    def similarity_function(self):
        return nn.CosineSimilarity(dim=-1)

    def get_contrastive_loss(self, encoder_out1, encoder_out2, sample1, sample2):

        def _sentence_embedding(encoder_out, sample):
            # Mean-pool encoder states over non-padding positions.
            encoder_output = encoder_out.transpose(0, 1)  # [batch, seq_len, hidden]
            src_tokens = sample["net_input"]["src_tokens"]
            mask = (src_tokens != self.padding_idx)
            encoder_embedding = (encoder_output * mask.unsqueeze(-1)).sum(dim=1) / mask.float().sum(dim=1).unsqueeze(-1)  # [batch, hidden_size]
            return encoder_embedding

        encoder_embedding1 = _sentence_embedding(encoder_out1, sample1)  # [batch, hidden_size]
        encoder_embedding2 = _sentence_embedding(encoder_out2, sample2)  # [batch, hidden_size]

        batch_size = encoder_embedding2.shape[0]
        feature_dim = encoder_embedding2.shape[1]
        anchor_feature = encoder_embedding1
        contrast_feature = encoder_embedding2

        # Pairwise cosine similarities between all anchor/contrast sentences;
        # true translation pairs lie on the diagonal.
        similarity_function = self.similarity_function()
        anchor_dot_contrast = similarity_function(
            anchor_feature.expand((batch_size, batch_size, feature_dim)),
            torch.transpose(contrast_feature.expand((batch_size, batch_size, feature_dim)), 0, 1))

        loss = -nn.LogSoftmax(0)(torch.div(anchor_dot_contrast, self.temperature)).diag().sum()

        return loss

    @classmethod
    def reduce_metrics(cls, logging_outputs) -> None:
        super().reduce_metrics(logging_outputs)
        nsentences = utils.item(
            sum(log.get("nsentences", 0) for log in logging_outputs)
        )
        contrastive_loss = utils.item(
            sum(log.get("contrastive_loss", 0) for log in logging_outputs)
        )
        metrics.log_scalar(
            "contrastive_loss",
            contrastive_loss / nsentences / math.log(2),
            nsentences,
            round=3,
        )
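For reference, the contrastive term above can be reproduced in isolation on dummy sentence embeddings. This is an illustrative sketch only (the `contrastive_loss` helper, batch size, hidden size, and temperature are made up), not code from the commit:

```python
# Standalone sketch of the contrastive loss used by the criterion above,
# applied to dummy mean-pooled sentence embeddings instead of encoder outputs.
import torch
import torch.nn as nn


def contrastive_loss(emb_src, emb_tgt, temperature=0.1):
    # emb_src, emb_tgt: [batch, hidden] sentence embeddings of parallel pairs.
    batch, dim = emb_src.shape
    sim = nn.CosineSimilarity(dim=-1)(
        emb_src.expand(batch, batch, dim),                  # anchors
        emb_tgt.expand(batch, batch, dim).transpose(0, 1),  # contrasts
    )
    # Matching pairs sit on the diagonal; every other entry acts as a negative.
    return -nn.LogSoftmax(0)(sim / temperature).diag().sum()


src = torch.randn(4, 8)
aligned = src + 0.01 * torch.randn(4, 8)         # near-parallel "translations"
print(contrastive_loss(src, aligned))            # small: diagonal dominates
print(contrastive_loss(src, torch.randn(4, 8)))  # typically larger for unrelated pairs
```

With near-parallel pairs the diagonal similarities dominate the softmax and the loss stays small; with unrelated embeddings it grows, and `--contrastive-lambda` controls how strongly that signal is mixed into the translation loss.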
1 change: 1 addition & 0 deletions mcolt/data/__init__.py
@@ -0,0 +1 @@
from .subsample_language_pair_dataset import SubsampleLanguagePairDataset
124 changes: 124 additions & 0 deletions mcolt/data/subsample_language_pair_dataset.py
@@ -0,0 +1,124 @@
from fairseq.data import BaseWrapperDataset, LanguagePairDataset, plasma_utils
import numpy as np

import logging

logger = logging.getLogger(__name__)


class SubsampleLanguagePairDataset(BaseWrapperDataset):
    """Subsamples a given dataset by a specified ratio. Subsampling is done on the number of examples.

    Args:
        dataset (~torch.utils.data.Dataset): dataset to subsample
        size_ratio (float): the ratio to subsample to; must be between 0 and 1 (exclusive)
    """

    def __init__(self, dataset, size_ratio, weights=None, replace=False, seed=0, epoch=1):
        super().__init__(dataset)
        assert size_ratio <= 1
        self.actual_size = np.ceil(len(dataset) * size_ratio).astype(int)
        logger.info(
            "subsampled dataset from {} to {} (ratio={})".format(
                len(self.dataset), self.actual_size, size_ratio
            )
        )
        self.src_dict = self.dataset.src_dict
        self.tgt_dict = self.dataset.tgt_dict
        self.left_pad_source = self.dataset.left_pad_source
        self.left_pad_target = self.dataset.left_pad_target
        # ordered_indices() relies on this flag; fall back to the wrapped
        # dataset's shuffle setting (LanguagePairDataset defaults to True).
        self.shuffle = getattr(self.dataset, "shuffle", True)
        self.seed = seed
        self._cur_epoch = None
        self._cur_indices = None
        self.replace = replace
        if weights is None:
            self.weights = None
        else:
            assert len(weights) == len(dataset)
            weights_arr = np.array(weights, dtype=np.float64)
            weights_arr /= weights_arr.sum()
            self.weights = plasma_utils.PlasmaArray(weights_arr)
        self.set_epoch(epoch)

    def __getitem__(self, index):
        # Map the index within the subsample back to an index of the full dataset.
        index = self._cur_indices.array[index]
        return self.dataset.__getitem__(index)

    def __len__(self):
        return self.actual_size

    @property
    def sizes(self):
        return self.dataset.sizes[self._cur_indices.array]

    @property
    def src_sizes(self):
        return self.dataset.src_sizes[self._cur_indices.array]

    @property
    def tgt_sizes(self):
        return self.dataset.tgt_sizes[self._cur_indices.array]

    @property
    def name(self):
        return self.dataset.name

    def num_tokens(self, index):
        index = self._cur_indices.array[index]
        return self.dataset.num_tokens(index)

    def size(self, index):
        index = self._cur_indices.array[index]
        return self.dataset.size(index)

    def ordered_indices(self):
        if self.shuffle:
            indices = np.random.permutation(len(self)).astype(np.int64)
        else:
            indices = np.arange(len(self), dtype=np.int64)
        # sort by target length, then source length
        if self.tgt_sizes is not None:
            indices = indices[np.argsort(self.tgt_sizes[indices], kind="mergesort")]
        return indices[np.argsort(self.src_sizes[indices], kind="mergesort")]

    def prefetch(self, indices):
        indices = self._cur_indices.array[indices]
        self.dataset.prefetch(indices)

    @property
    def can_reuse_epoch_itr_across_epochs(self):
        # The subsample changes every epoch, so the iterator cannot be reused.
        return False

    def set_epoch(self, epoch):
        logger.info("SubsampleLanguagePairDataset.set_epoch: {}".format(epoch))
        super().set_epoch(epoch)

        if epoch == self._cur_epoch:
            return

        self._cur_epoch = epoch

        # Generate a weighted sample of indices as a function of the
        # random seed and the current epoch.
        rng = np.random.RandomState(
            [
                42,  # magic number
                self.seed % (2 ** 32),  # global seed
                self._cur_epoch,  # epoch index
            ]
        )
        self._cur_indices = plasma_utils.PlasmaArray(
            rng.choice(
                len(self.dataset),
                self.actual_size,
                replace=self.replace,
                p=(None if self.weights is None else self.weights.array),
            )
        )

        logger.info(
            "Dataset is sub-sampled: {} -> {}, first 3 ids are: {}".format(
                len(self.dataset),
                self.actual_size,
                ",".join(str(_i) for _i in self._cur_indices.array[:3]),
            )
        )
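A small usage sketch of the wrapper above. A real `LanguagePairDataset` needs binarized data and dictionaries, so this uses an invented `ToyPairDataset` stand-in that only exposes the attributes the wrapper reads; everything here is illustrative, not part of the commit:

```python
# Toy demonstration of per-epoch subsampling; ToyPairDataset is a made-up
# stand-in for a fairseq LanguagePairDataset.
import numpy as np

from mcolt.data import SubsampleLanguagePairDataset


class ToyPairDataset:
    src_dict = tgt_dict = None
    left_pad_source, left_pad_target = True, False
    shuffle = True

    def __init__(self, n=100):
        self.src_sizes = np.random.randint(5, 30, size=n)
        self.tgt_sizes = np.random.randint(5, 30, size=n)
        self.sizes = np.maximum(self.src_sizes, self.tgt_sizes)

    def __len__(self):
        return len(self.sizes)

    def __getitem__(self, index):
        return {"id": index}


ds = SubsampleLanguagePairDataset(ToyPairDataset(), size_ratio=0.07, seed=33122)
print(len(ds))   # 7 of the 100 toy examples are kept
first = ds[0]
ds.set_epoch(2)  # a new epoch draws a fresh subsample with a different seed
print(first, ds[0])
```

Each `set_epoch` call redraws the subsample with an epoch-dependent seed, so the `mono_ratio` and `parallel_ratio` values in the config above presumably act as per-epoch sampling rates for the monolingual and parallel bins (how the task wires them in is not shown in this commit).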
2 changes: 2 additions & 0 deletions mcolt/tasks/__init__.py
@@ -0,0 +1,2 @@
from .translation_w_mono import *
from .translation_w_langtok import *
