diff --git a/TinyBERT/data_augmentation.py b/TinyBERT/data_augmentation.py
index b817b865..6081cf18 100644
--- a/TinyBERT/data_augmentation.py
+++ b/TinyBERT/data_augmentation.py
@@ -73,7 +73,8 @@ def strip_accents(text):
 
 # valid string only includes al
 def _is_valid(string):
-    return True if not re.search('[^a-z]', string) else False
+    # Lower-case the string so cased models are also supported
+    return True if not re.search('[^a-z]', string.lower()) else False
 
 
 def _read_tsv(input_file, quotechar=None):
@@ -141,9 +142,17 @@ def _word_distance(self, word):
         dist[word_idx] = -np.Inf
 
         candidate_ids = np.argsort(-dist)[:self.M]
-        return [self.ids_to_tokens[idx] for idx in candidate_ids][:self.M]
+        candidate_words = [self.ids_to_tokens[idx] for idx in candidate_ids][:self.M]
+
+        if word.istitle():
+            # capitalize the first letter of each word to preserve case
+            candidate_words = [w.title() for w in candidate_words]
+        return candidate_words
 
     def _masked_language_model(self, sent, word_pieces, mask_id):
+        if mask_id >= 512:
+            return []
+
         tokenized_text = self.tokenizer.tokenize(sent)
         tokenized_text = ['[CLS]'] + tokenized_text
         tokenized_len = len(tokenized_text)
@@ -152,6 +161,8 @@ def _masked_language_model(self, sent, word_pieces, mask_id):
 
         if len(tokenized_text) > 512:
             tokenized_text = tokenized_text[:512]
+        if tokenized_len >= 512:
+            tokenized_len = 511
 
         token_ids = self.tokenizer.convert_tokens_to_ids(tokenized_text)
         segments_ids = [0] * (tokenized_len + 1) + [1] * (len(tokenized_text) - tokenized_len - 1)
@@ -206,7 +217,7 @@ def augment(self, sent):
         tokens = self.tokenizer.basic_tokenizer.tokenize(sent)
         candidate_words = {}
         for (idx, word) in enumerate(tokens):
-            if _is_valid(word) and word not in StopWordsList:
+            if _is_valid(word) and word.lower() not in StopWordsList:
                 candidate_words[idx] = self._word_augment(sent, idx, word)
         logger.info(candidate_words)
         cnt = 0
@@ -269,7 +280,7 @@ def main():
     parser = argparse.ArgumentParser()
     parser.add_argument("--pretrained_bert_model", default=None, type=str, required=True,
-                        help="Downloaded pretrained model (bert-base-uncased) is under this folder")
+                        help="Downloaded pretrained model (bert-base-cased or bert-base-uncased) is under this folder")
     parser.add_argument("--glove_embs", default=None, type=str, required=True,
                         help="Glove word embeddings file")
    parser.add_argument("--glue_dir", default=None, type=str, required=True,
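
For context, a minimal standalone sketch of the case handling this patch introduces. `_is_valid` below mirrors the patched check; `_restore_case` is a hypothetical name for the title-casing step that lives inline in `_word_distance`, and the candidate list is made up for illustration:

    import re

    # Patched validity check: lower-case first, so capitalized words like
    # "Hello" now pass when a cased tokenizer produced them.
    def _is_valid(string):
        return True if not re.search('[^a-z]', string.lower()) else False

    # Title-casing step from _word_distance: candidates from the (typically
    # lower-cased) GloVe vocabulary are re-capitalized when the source word
    # was capitalized, so they fit back into a cased sentence.
    def _restore_case(word, candidate_words):
        if word.istitle():
            candidate_words = [w.title() for w in candidate_words]
        return candidate_words

    print(_is_valid("Hello"))    # True  (was False before the patch)
    print(_is_valid("don't"))    # False (non-letter characters still rejected)
    print(_restore_case("Paris", ["city", "town", "capital"]))
    # ['City', 'Town', 'Capital']

The two 512-related guards in `_masked_language_model` address a separate failure mode: BERT's position embeddings cap inputs at 512 tokens, so a mask index at or beyond that limit is skipped outright, and `tokenized_len` is clamped to 511 so the segment-id arithmetic stays non-negative after the input is truncated.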