Commit a2c918c: initial commit (0 parents) by panxiao.94, May 20, 2021.
Showing 16 changed files with 1,485 additions and 0 deletions.
.gitignore
@@ -0,0 +1 @@
.idea
README.md
@@ -0,0 +1,4 @@
# Contrastive Learning for Many-to-many Multilingual Neural Machine Translation (mCOLT), ACL 2021
The code for training mCOLT, a multilingual NMT training framework implemented on top of fairseq.
Training configuration (YAML)
@@ -0,0 +1,44 @@
model_dir: model/pretrain/lab/multilingual/l2r/multi_bpe32k/parallel_mono_contrastive_1/transformer_big_t2t_12e12d
data_1: data/multilingual/bin/merged_deduped_ras
data_mono_1: data/multilingual/bin/mono_only/splitaa
data_mono_2: data/multilingual/bin/mono_only/splitab
data_mono_3: data/multilingual/bin/mono_only/splitac
data_mono_4: data/multilingual/bin/mono_only/splitad
data_mono_5: data/multilingual/bin/mono_only/splitae
data_mono_6: data/multilingual/bin/mono_only/mono_de_fr_en
data_mono_7: data/multilingual/bin/mono_only/mono_nl_pl_pt
source_lang: src
target_lang: trg
task: translation_w_mono
parallel_ratio: 0.2
mono_ratio: 0.07
arch: transformer_big_t2t_12e12d
share_all_embeddings: true
encoder_learned_pos: true
decoder_learned_pos: true
max_source_positions: 1024
max_target_positions: 1024
dropout: 0.1
criterion: label_smoothed_cross_entropy_with_contrastive
contrastive_lambda: 1.0
temperature: 0.1
lr: 0.0003
clip_norm: 10.0
optimizer: adam
adam_eps: 1e-06
weight_decay: 0.01
warmup_updates: 10000
label_smoothing: 0.1
lr_scheduler: polynomial_decay
min_lr: -1
max_tokens: 1536
update_freq: 30
max_update: 5000000
no_scale_embedding: true
layernorm_embedding: true
save_interval_updates: 2000
skip_invalid_size_inputs_valid_test: true
log_interval: 500
num_workers: 1
fp16: true
seed: 33122
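fairseq consumes these settings as command-line flags rather than YAML, so the repo presumably ships a launcher that expands this file into a fairseq-train invocation. Below is a minimal sketch of such a helper. It is hypothetical, not from this commit: it assumes PyYAML is available, that each fairseq-recognizable key maps to a flag by swapping underscores for dashes, and that repo-specific keys (model_dir, data_1, the data_mono_* entries) would get custom handling in the real launcher; the config filename is also made up.

import yaml  # assumes PyYAML is installed


def config_to_flags(path):
    # Expand YAML keys into fairseq-train style flags:
    # underscores become dashes, `true` booleans become bare flags.
    with open(path) as f:
        cfg = yaml.safe_load(f)
    flags = []
    for key, value in cfg.items():
        flag = "--" + key.replace("_", "-")
        if value is True:
            flags.append(flag)
        else:
            flags.extend([flag, str(value)])
    return flags


print(" ".join(config_to_flags("parallel_mono_contrastive.yml")))  # hypothetical file name

Note the effective batch size implied by the config: max_tokens 1536 with update_freq 30 accumulates roughly 46K tokens per GPU per optimizer step.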
mcolt/__init__.py
@@ -0,0 +1,4 @@
from .arches import *
from .criterions import *
from .data import *
from .tasks import *
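This top-level __init__ is what makes the extensions visible to fairseq: launching fairseq-train with --user-dir pointing at the mcolt package imports these modules, and the register_* decorators inside them add the custom architectures, criterion, dataset, and tasks to fairseq's registries. A minimal sketch of that mechanism follows (an illustration, assuming a 2021-era fairseq where utils.import_user_module and CRITERION_REGISTRY are available):

from argparse import Namespace

from fairseq import utils

# Equivalent of passing --user-dir mcolt to fairseq-train:
# importing the package runs the registration decorators.
utils.import_user_module(Namespace(user_dir="mcolt"))

from fairseq.criterions import CRITERION_REGISTRY

assert "label_smoothed_cross_entropy_with_contrastive" in CRITERION_REGISTRY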
mcolt/arches/__init__.py
@@ -0,0 +1 @@
from .transformer import *
mcolt/arches/transformer.py: large diff, not rendered in this view.
mcolt/criterions/__init__.py
@@ -0,0 +1 @@
from .label_smoothed_cross_entropy_with_contrastive import *
mcolt/criterions/label_smoothed_cross_entropy_with_contrastive.py
@@ -0,0 +1,123 @@
import math

import torch
import torch.nn as nn

from fairseq import metrics, utils
from fairseq.criterions import register_criterion
from fairseq.criterions.label_smoothed_cross_entropy import LabelSmoothedCrossEntropyCriterion


@register_criterion("label_smoothed_cross_entropy_with_contrastive")
class LabelSmoothedCrossEntropyCriterionWithContrastive(
    LabelSmoothedCrossEntropyCriterion
):
    """Label-smoothed cross-entropy plus a contrastive term that pulls the
    mean-pooled encoder representations of an aligned sentence pair together
    and pushes them apart from the other sentences in the batch."""

    def __init__(self, task, sentence_avg, label_smoothing, ignore_prefix_size=0,
                 report_accuracy=False, contrastive_lambda=0.0, temperature=1.0):
        super().__init__(task, sentence_avg, label_smoothing, ignore_prefix_size, report_accuracy)
        self.contrastive_lambda = contrastive_lambda
        self.temperature = temperature

    @staticmethod
    def add_args(parser):
        LabelSmoothedCrossEntropyCriterion.add_args(parser)
        parser.add_argument("--contrastive-lambda", type=float, default=0.0,
                            help="The contrastive loss weight")
        parser.add_argument("--temperature", type=float, default=1.0,
                            help="Softmax temperature for the contrastive loss")

    def swap_sample(self, sample):
        """Build the reverse-direction sample: the target becomes the source,
        and the source (prefixed with the original BOS token) becomes the target."""
        target = sample["target"]
        prev_output_tokens = sample["net_input"]["prev_output_tokens"]
        src_tokens = torch.cat((prev_output_tokens[:, :1], sample["net_input"]["src_tokens"]), dim=-1)
        return {
            "net_input": {
                "src_tokens": target.contiguous(),
                "src_lengths": (target != self.padding_idx).int().sum(dim=1),
                "prev_output_tokens": src_tokens[:, :-1].contiguous(),
            },
            "nsentences": sample["nsentences"],
            "ntokens": utils.item((src_tokens[:, 1:] != self.padding_idx).int().sum().data),
            "target": src_tokens[:, 1:].contiguous(),
            "id": sample["id"],
        }

    def forward(self, model, sample, reduce=True):
        net_output = model(**sample["net_input"])
        loss, nll_loss = self.compute_loss(model, net_output, sample, reduce=reduce)
        # Encode both directions: the original source and the swapped
        # (target-as-source) sample.
        encoder_out = model.encoder.forward(
            sample["net_input"]["src_tokens"], sample["net_input"]["src_lengths"]
        ).encoder_out
        reverse_sample = self.swap_sample(sample)
        reversed_encoder_out = model.encoder.forward(
            reverse_sample["net_input"]["src_tokens"], reverse_sample["net_input"]["src_lengths"]
        ).encoder_out
        contrastive_loss = self.get_contrastive_loss(
            encoder_out,
            reversed_encoder_out,
            sample,
            reverse_sample,
        )
        sample_size = (
            sample["target"].size(0) if self.sentence_avg else sample["ntokens"]
        )
        nsentences = sample["target"].size(0)
        ntokens = sample["ntokens"]
        # The CE loss is summed over tokens while the contrastive loss is summed
        # over sentences; scale the latter by tokens-per-sentence so the two terms
        # are on a comparable footing before weighting by contrastive_lambda.
        all_loss = loss + contrastive_loss * self.contrastive_lambda * ntokens / nsentences
        logging_output = {
            "loss": loss.data,
            "nll_loss": nll_loss.data,
            "ntokens": ntokens,
            "nsentences": nsentences,
            "sample_size": sample_size,
        }
        if isinstance(contrastive_loss, int):
            logging_output["contrastive_loss"] = 0
        else:
            logging_output["contrastive_loss"] = utils.item(contrastive_loss.data)

        return all_loss, sample_size, logging_output

    def similarity_function(self):
        return nn.CosineSimilarity(dim=-1)

    def get_contrastive_loss(self, encoder_out1, encoder_out2, sample1, sample2):

        def _sentence_embedding(encoder_out, sample):
            # Mean-pool the encoder states over non-padding positions.
            encoder_output = encoder_out.transpose(0, 1)  # [batch, seq_len, hidden_size]
            src_tokens = sample["net_input"]["src_tokens"]
            mask = (src_tokens != self.padding_idx)
            encoder_embedding = (encoder_output * mask.unsqueeze(-1)).sum(dim=1) / mask.float().sum(dim=1).unsqueeze(-1)  # [batch, hidden_size]
            return encoder_embedding

        encoder_embedding1 = _sentence_embedding(encoder_out1, sample1)  # [batch, hidden_size]
        encoder_embedding2 = _sentence_embedding(encoder_out2, sample2)  # [batch, hidden_size]

        batch_size = encoder_embedding2.shape[0]
        feature_dim = encoder_embedding2.shape[1]
        anchor_feature = encoder_embedding1
        contrast_feature = encoder_embedding2

        # Pairwise cosine similarities between every anchor and every contrast
        # sentence in the batch, via broadcasting to [batch, batch, hidden_size].
        similarity_function = self.similarity_function()
        anchor_dot_contrast = similarity_function(
            anchor_feature.expand((batch_size, batch_size, feature_dim)),
            torch.transpose(contrast_feature.expand((batch_size, batch_size, feature_dim)), 0, 1),
        )

        # InfoNCE-style objective: LogSoftmax over dim 0 normalizes each anchor's
        # scores over all in-batch candidates; the diagonal holds the aligned pairs,
        # which should score highest.
        loss = -nn.LogSoftmax(0)(torch.div(anchor_dot_contrast, self.temperature)).diag().sum()

        return loss

    @classmethod
    def reduce_metrics(cls, logging_outputs) -> None:
        super().reduce_metrics(logging_outputs)
        nsentences = utils.item(
            sum(log.get("nsentences", 0) for log in logging_outputs)
        )
        contrastive_loss = utils.item(
            sum(log.get("contrastive_loss", 0) for log in logging_outputs)
        )
        # Report the per-sentence contrastive loss in bits (hence log 2).
        metrics.log_scalar(
            "contrastive_loss",
            contrastive_loss / nsentences / math.log(2),
            nsentences,
            round=3,
        )
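To see the contrastive term in isolation, here is a self-contained toy check (an illustration, not part of the repo) that mirrors the arithmetic in get_contrastive_loss on random vectors: each row of anchor should be most similar to its own row in contrast, so the diagonal of the in-batch similarity matrix should win the softmax and the loss should be near zero.

import torch
import torch.nn as nn

torch.manual_seed(0)
batch_size, hidden = 3, 4
anchor = torch.randn(batch_size, hidden)
contrast = anchor + 0.1 * torch.randn(batch_size, hidden)  # noisy "translations"

cos = nn.CosineSimilarity(dim=-1)
sim = cos(
    anchor.expand(batch_size, batch_size, hidden),
    contrast.expand(batch_size, batch_size, hidden).transpose(0, 1),
)  # sim[i, j] = cos(anchor[j], contrast[i])

# InfoNCE over in-batch candidates, with temperature 0.1 as in the config above.
loss = -nn.LogSoftmax(0)(sim / 0.1).diag().sum()
print(loss.item())  # small, since each anchor matches its own contrast vector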
mcolt/data/__init__.py
@@ -0,0 +1 @@
from .subsample_language_pair_dataset import SubsampleLanguagePairDataset
mcolt/data/subsample_language_pair_dataset.py
@@ -0,0 +1,124 @@
import logging

import numpy as np

from fairseq.data import BaseWrapperDataset, LanguagePairDataset, plasma_utils

logger = logging.getLogger(__name__)


class SubsampleLanguagePairDataset(BaseWrapperDataset):
    """Subsamples a given dataset by a specified ratio. Subsampling is done on
    the number of examples.

    Args:
        dataset (~torch.utils.data.Dataset): dataset to subsample
        size_ratio (float): the ratio to subsample to; must be in (0, 1]
    """

    def __init__(self, dataset, size_ratio, weights=None, replace=False, seed=0, epoch=1):
        super().__init__(dataset)
        assert size_ratio <= 1
        self.actual_size = np.ceil(len(dataset) * size_ratio).astype(int)
        logger.info(
            "subsampled dataset from {} to {} (ratio={})".format(
                len(self.dataset), self.actual_size, size_ratio
            )
        )
        self.src_dict = self.dataset.src_dict
        self.tgt_dict = self.dataset.tgt_dict
        self.left_pad_source = self.dataset.left_pad_source
        self.left_pad_target = self.dataset.left_pad_target
        self.seed = seed
        self._cur_epoch = None
        self._cur_indices = None
        self.replace = replace
        if weights is None:
            self.weights = None
        else:
            # Normalize the per-example weights into a sampling distribution.
            assert len(weights) == len(dataset)
            weights_arr = np.array(weights, dtype=np.float64)
            weights_arr /= weights_arr.sum()
            self.weights = plasma_utils.PlasmaArray(weights_arr)
        self.set_epoch(epoch)

    def __getitem__(self, index):
        # Map the subsampled index back to an index into the wrapped dataset.
        index = self._cur_indices.array[index]
        return self.dataset.__getitem__(index)

    def __len__(self):
        return self.actual_size

    @property
    def sizes(self):
        return self.dataset.sizes[self._cur_indices.array]

    @property
    def src_sizes(self):
        return self.dataset.src_sizes[self._cur_indices.array]

    @property
    def tgt_sizes(self):
        return self.dataset.tgt_sizes[self._cur_indices.array]

    @property
    def name(self):
        return self.dataset.name

    def num_tokens(self, index):
        index = self._cur_indices.array[index]
        return self.dataset.num_tokens(index)

    def size(self, index):
        index = self._cur_indices.array[index]
        return self.dataset.size(index)

    def ordered_indices(self):
        if self.shuffle:
            indices = np.random.permutation(len(self)).astype(np.int64)
        else:
            indices = np.arange(len(self), dtype=np.int64)
        # sort by target length, then source length
        if self.tgt_sizes is not None:
            indices = indices[np.argsort(self.tgt_sizes[indices], kind="mergesort")]
        return indices[np.argsort(self.src_sizes[indices], kind="mergesort")]

    def prefetch(self, indices):
        indices = self._cur_indices.array[indices]
        self.dataset.prefetch(indices)

    @property
    def can_reuse_epoch_itr_across_epochs(self):
        # The sampled subset changes every epoch, so iterators cannot be reused.
        return False

    def set_epoch(self, epoch):
        logger.info("SubsampleLanguagePairDataset.set_epoch: {}".format(epoch))
        super().set_epoch(epoch)

        if epoch == self._cur_epoch:
            return

        self._cur_epoch = epoch

        # Generate a weighted sample of indices as a function of the
        # random seed and the current epoch.
        rng = np.random.RandomState(
            [
                42,  # magic number
                self.seed % (2 ** 32),  # global seed
                self._cur_epoch,  # epoch index
            ]
        )
        self._cur_indices = plasma_utils.PlasmaArray(
            rng.choice(
                len(self.dataset),
                self.actual_size,
                replace=self.replace,
                p=(None if self.weights is None else self.weights.array),
            )
        )

        logger.info(
            "Dataset is sub-sampled: {} -> {}, first 3 ids are: {}".format(
                len(self.dataset),
                self.actual_size,
                ",".join(str(_i) for _i in self._cur_indices.array[:3]),
            )
        )
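A minimal usage sketch follows (the toy Dictionary and tensors are illustrative, not from the repo; it assumes fairseq is installed), cutting a 10-example LanguagePairDataset in half. In training, the config's parallel_ratio and mono_ratio presumably arrive here as size_ratio.

import torch
from fairseq.data import Dictionary, LanguagePairDataset

vocab = Dictionary()
for w in ("ein", "kleiner", "test"):
    vocab.add_symbol(w)

# ten identical toy sentence pairs, each ending in EOS
sents = [torch.tensor([vocab.index("ein"), vocab.index("test"), vocab.eos()])
         for _ in range(10)]
sizes = [len(s) for s in sents]
pair = LanguagePairDataset(sents, sizes, vocab, tgt=sents, tgt_sizes=sizes, tgt_dict=vocab)

subset = SubsampleLanguagePairDataset(pair, size_ratio=0.5, seed=33122, epoch=1)
print(len(pair), "->", len(subset))  # 10 -> 5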
mcolt/tasks/__init__.py
@@ -0,0 +1,2 @@
from .translation_w_mono import *
from .translation_w_langtok import *