Update the code to BERT tokenization.
zaidalyafeai committed Aug 22, 2020
1 parent 63cbf12 commit bb38343
Showing 1 changed file with 47 additions and 1 deletion.
48 changes: 47 additions & 1 deletion tkseem/_base.py
@@ -142,7 +142,7 @@ def _split_word_cached(self, word, number_of_subwords):
        all_binaries = self.cached[n, number_of_subwords - 1]
        return [split_on_binary(word, binary) for binary in all_binaries]

-    def _tokenize_from_dict(self, text, freq_dict, cache=False, max_size=20):
+    def _tokenize_from_dict_deprecated(self, text, freq_dict, cache=False, max_size=20):
        """Tokenize using frequency based approach given a dictionary
        Args:
@@ -195,6 +195,52 @@ def _tokenize_from_dict(self, text, freq_dict, cache=False, max_size=20):
            output_tokens.append(str(token))
        return output_tokens

    # https://github.com/google-research/bert/blob/eedf5716ce1268e56f0a50264a88cafad334ac61/tokenization.py#L308
    def _tokenize_from_dict(self, text, freq_dict, max_size=20):
        """Tokenize using a frequency-based, greedy longest-match-first approach
        given a dictionary.

        Args:
            text (str): input string
            freq_dict (dict): frequency dictionary
            max_size (int, optional): maximum word size. Defaults to 20.

        Returns:
            list: list of subword tokens
        """

        output_tokens = []
        for token in text.split():
            chars = list(token)
            if len(chars) > max_size:
                output_tokens.append(self.unk_token)
                continue

            is_bad = False
            start = 0
            sub_tokens = []
            while start < len(chars):
                end = len(chars)
                cur_substr = None
                # Greedily take the longest substring found in the dictionary,
                # shrinking the candidate from the right one character at a time.
                while start < end:
                    substr = "".join(chars[start:end])
                    if start > 0:
                        substr = "##" + substr
                    if substr in freq_dict:
                        cur_substr = substr
                        break
                    end -= 1
                if cur_substr is None:
                    # No dictionary entry matches this span; mark the whole word unknown.
                    is_bad = True
                    break
                sub_tokens.append(cur_substr)
                start = end
            if is_bad:
                output_tokens.append(self.unk_token)
            else:
                output_tokens.extend(sub_tokens)
        return output_tokens

    def _truncate_dict(self, freq_dict):
        """Truncate a frequency dictionary and add reserved tokens
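For reference, the new `_tokenize_from_dict` follows BERT's WordPiece lookup: for each whitespace-split token it greedily takes the longest prefix found in the vocabulary, marks word-internal pieces with `##`, and falls back to the unknown token when no piece matches. Below is a minimal standalone sketch of that greedy longest-match-first loop; the `wordpiece_split` helper, the toy vocabulary, and the `<UNK>` placeholder are illustrative only and are not part of this commit or of tkseem's API.

# Standalone sketch of the greedy longest-match-first (WordPiece-style) lookup
# used above. The toy vocabulary is hypothetical; in tkseem the vocabulary
# comes from the trained freq_dict.
def wordpiece_split(word, vocab, unk_token="<UNK>"):
    chars = list(word)
    sub_tokens = []
    start = 0
    while start < len(chars):
        end = len(chars)
        cur_substr = None
        # Try the longest remaining substring first, shrinking from the right.
        while start < end:
            substr = "".join(chars[start:end])
            if start > 0:
                substr = "##" + substr  # mark word-internal pieces
            if substr in vocab:
                cur_substr = substr
                break
            end -= 1
        if cur_substr is None:
            return [unk_token]  # no piece matched: treat the whole word as unknown
        sub_tokens.append(cur_substr)
        start = end
    return sub_tokens

toy_vocab = {"play", "##ing", "##ed"}
print(wordpiece_split("playing", toy_vocab))  # ['play', '##ing']
print(wordpiece_split("plays", toy_vocab))    # ['<UNK>'] because '##s' is not in the vocabulary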
