Fix augmentation bug and support cased models
gowtham1997 committed Sep 27, 2021
1 parent 54ca698 commit 170f475
Showing 1 changed file with 15 additions and 4 deletions.
19 changes: 15 additions & 4 deletions TinyBERT/data_augmentation.py
@@ -73,7 +73,8 @@ def strip_accents(text):
 
 # valid string only includes alphabetic characters
 def _is_valid(string):
-    return True if not re.search('[^a-z]', string) else False
+    # Adding string.lower to also support cased model
+    return True if not re.search('[^a-z]', string.lower()) else False
 
 
 def _read_tsv(input_file, quotechar=None):
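To see what the lowered check changes, here is a standalone sketch of the fixed _is_valid (the function body is from the commit; the probe words are illustrative):

import re

def _is_valid(string):
    # lower-case first so cased tokens (e.g. from bert-base-cased input) pass
    return True if not re.search('[^a-z]', string.lower()) else False

print(_is_valid("Paris"))   # True after this commit; False before
print(_is_valid("paris"))   # True before and after
print(_is_valid("1984"))    # False: digits are never valid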
@@ -141,9 +142,17 @@ def _word_distance(self, word):
         dist[word_idx] = -np.Inf
 
         candidate_ids = np.argsort(-dist)[:self.M]
-        return [self.ids_to_tokens[idx] for idx in candidate_ids][:self.M]
+        candidate_words = [self.ids_to_tokens[idx] for idx in candidate_ids][:self.M]
+
+        if word.istitle():
+            # capitalize the first letter of each word to preserve case
+            candidate_words = [w.title() for w in candidate_words]
+        return candidate_words
 
     def _masked_language_model(self, sent, word_pieces, mask_id):
+        if mask_id >= 512:
+            return []
+
         tokenized_text = self.tokenizer.tokenize(sent)
         tokenized_text = ['[CLS]'] + tokenized_text
         tokenized_len = len(tokenized_text)
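This hunk keeps the GloVe lookup lower-case but re-capitalizes the returned neighbours when the query word is title-cased; the new mask_id >= 512 guard skips words whose mask position falls at or beyond the 512-token window that the truncation below enforces. A small sketch of the case-preserving step, with an invented neighbour list:

word = "Paris"                                   # hypothetical title-cased query
candidate_words = ["london", "berlin", "rome"]   # hypothetical GloVe neighbours

if word.istitle():
    # capitalize the first letter of each candidate to preserve case
    candidate_words = [w.title() for w in candidate_words]

print(candidate_words)   # ['London', 'Berlin', 'Rome']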
@@ -152,6 +161,8 @@ def _masked_language_model(self, sent, word_pieces, mask_id):
 
         if len(tokenized_text) > 512:
             tokenized_text = tokenized_text[:512]
+        if tokenized_len >= 512:
+            tokenized_len = 511
 
         token_ids = self.tokenizer.convert_tokens_to_ids(tokenized_text)
         segments_ids = [0] * (tokenized_len + 1) + [1] * (len(tokenized_text) - tokenized_len - 1)
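The cap at 511 exists because segments_ids is built from tokenized_len after tokenized_text has already been cut to 512 tokens; an uncapped length makes the second multiplier negative, so the segment list no longer matches the token list. A numeric sketch (the lengths are hypothetical):

text_len = 512        # len(tokenized_text) after truncation
tokenized_len = 600   # hypothetical pre-truncation sentence length

segments_ids = [0] * (tokenized_len + 1) + [1] * (text_len - tokenized_len - 1)
print(len(segments_ids))   # 601 -- longer than the 512 token ids

tokenized_len = 511        # the committed cap
segments_ids = [0] * (tokenized_len + 1) + [1] * (text_len - tokenized_len - 1)
print(len(segments_ids))   # 512 -- matches token_ids again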
@@ -206,7 +217,7 @@ def augment(self, sent):
         tokens = self.tokenizer.basic_tokenizer.tokenize(sent)
         candidate_words = {}
         for (idx, word) in enumerate(tokens):
-            if _is_valid(word) and word not in StopWordsList:
+            if _is_valid(word) and word.lower() not in StopWordsList:
                 candidate_words[idx] = self._word_augment(sent, idx, word)
         logger.info(candidate_words)
         cnt = 0
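With a cased tokenizer, stop words that keep their capitalization at the start of a sentence ("The", "A") previously slipped past the membership test; lowering the word first restores the filter. A sketch with an assumed stop-word list (the real StopWordsList contents are not shown in this diff):

StopWordsList = {"the", "a", "of"}   # assumed contents, for illustration only

for word in ["The", "the", "Paris"]:
    if word.lower() not in StopWordsList:
        print(word, "-> kept as an augmentation candidate")
    else:
        print(word, "-> filtered out as a stop word")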
@@ -269,7 +280,7 @@ def main():
     parser = argparse.ArgumentParser()
 
     parser.add_argument("--pretrained_bert_model", default=None, type=str, required=True,
-                        help="Downloaded pretrained model (bert-base-uncased) is under this folder")
+                        help="Downloaded pretrained model (bert-base-cased/uncased) is under this folder")
     parser.add_argument("--glove_embs", default=None, type=str, required=True,
                         help="Glove word embeddings file")
     parser.add_argument("--glue_dir", default=None, type=str, required=True,