-
Notifications
You must be signed in to change notification settings - Fork 0
/
bpetoken.py
158 lines (139 loc) · 5.04 KB
/
bpetoken.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
import collections
import re
import os
from lxml import etree
## For learning a BPE Encoding from a corpus, and to tokenise the input texts and queries after learning
## get mapping key -> frequency mapping
## key is a tuple split on characters by default.
## for example, (l, o, w, e, s, t, _ ) -> 5
def make_vocab(text):
    """Build a word-frequency vocabulary from raw text.

    Each whitespace-separated token is stripped of non-word characters,
    split into single characters, and terminated with the end-of-word
    marker "_", e.g. "lowest" -> ('l', 'o', 'w', 'e', 's', 't', '_').

    Fix: tokens that are pure punctuation (empty after stripping) are
    now skipped; previously they were counted as a bogus ('_',) entry.

    Args:
        text: raw text to vocabularise.

    Returns:
        dict mapping character tuples to their frequency in *text*.
    """
    counts = collections.Counter()
    for raw in text.split():
        word = re.sub(r'\W+', '', raw)
        if not word:
            continue  # punctuation-only token, not a word
        counts[tuple(word) + ('_',)] += 1
    return dict(counts)
## Gives all pairs of sequences from vocabulary, with their frequencies
def get_pairs(vocab):
    """Count the frequency of every adjacent symbol pair in *vocab*.

    Args:
        vocab: dict mapping symbol tuples to word frequencies.

    Returns:
        defaultdict mapping (left, right) symbol pairs to total frequency.
    """
    counts = collections.defaultdict(int)
    for symbols, freq in vocab.items():
        # Slide over each adjacent (left, right) pair of the key.
        for left, right in zip(symbols, symbols[1:]):
            counts[(left, right)] += freq
    return counts
## Merges neighboring chars in keys that have the given pair
## for example, (l, o, w, e, s, t, _) -> (l, o, w, es, t, _) if pair == (e, s)
def merge_vocab(pair, vocab):
    """Merge every adjacent occurrence of *pair* inside the vocab keys.

    E.g. with pair == ('e', 's'),
    ('l','o','w','e','s','t','_') -> ('l','o','w','es','t','_').
    Merging is greedy left-to-right, so ('a','a','a') with pair
    ('a','a') becomes ('aa', 'a').

    Fixes over the previous implementation:
      * repeated-symbol pairs no longer duplicate the trailing symbol
        (('a','a','a','a') with ('a','a') now yields ('aa','aa'), not
        ('aa','aa','a'));
      * when two distinct keys collapse to the same merged key, their
        frequencies are summed instead of the last overwriting.

    Args:
        pair:  2-tuple of adjacent symbols to merge.
        vocab: dict mapping symbol tuples to frequencies; mutated in
               place (keys containing the pair are replaced).

    Returns:
        The same (mutated) vocab dict, for convenience.
    """
    replacements = []
    for key, freq in vocab.items():
        merged = []
        i = 0
        n = len(key)
        while i < n:
            if i + 1 < n and key[i] == pair[0] and key[i + 1] == pair[1]:
                merged.append(key[i] + key[i + 1])
                i += 2  # consume both symbols of the pair
            else:
                merged.append(key[i])
                i += 1
        merged = tuple(merged)
        if merged != key:
            replacements.append((key, merged, freq))
    # Apply after iteration so the dict is not mutated while iterating.
    for old_key, new_key, freq in replacements:
        del vocab[old_key]
        vocab[new_key] = vocab.get(new_key, 0) + freq
    return vocab
# Finds and merges the most frequent neighbouring pair, num_merges times
def learn_tokens(ts, num_merges):
    """Learn up to *num_merges* BPE merge rules from the vocabulary.

    Repeatedly finds the most frequent adjacent symbol pair and merges
    it throughout the vocabulary (mutated in place by merge_vocab).

    Fix: the previous version required at least two distinct pairs
    (len(pr) > 1), so the last remaining pair could never be merged,
    and the loop kept iterating uselessly once nothing was mergeable.
    Now any non-empty pair set is merged, and the loop stops early
    when no pairs remain.

    Args:
        ts:         vocabulary dict (symbol tuple -> frequency); mutated.
        num_merges: maximum number of merge rules to learn.

    Returns:
        List of merged (left, right) pairs, in the order learnt.
    """
    print("Learning merges!")
    merges = []
    for _ in range(num_merges):
        pr = get_pairs(ts)
        if not pr:
            break  # nothing left to merge
        best = max(pr, key=pr.get)
        ts = merge_vocab(best, ts)
        merges.append(best)
    return merges
#learn_tokens(text, 8)
## returns the vocabulary for learning tokens in this folder
def _tally_words(text, vocab):
    """Count each word of *text* into *vocab* in place.

    Words are lower-cased, stripped of non-word characters, split into
    characters and terminated with the end-of-word marker '_'.
    Punctuation-only tokens are skipped (previously they were counted
    as a bogus ('_',) entry).
    """
    for raw in text.split():
        word = re.sub(r'\W+', '', raw.lower())
        if not word:
            continue
        word_tuple = tuple(word) + ('_',)
        vocab[word_tuple] = vocab.get(word_tuple, 0) + 1


def vocabularise_documents(foldername):
    """Build a BPE vocabulary from the XML documents in *foldername*.

    Every 4th file of the directory listing is read (sampling inherited
    from the original code -- NOTE(review): os.listdir order is
    arbitrary, so which files are sampled is not deterministic; confirm
    this is intended). Each file is wrapped in an <INIT> root so that
    multiple top-level <DOC> elements parse, then the text content of
    DOC/TITLE and DOC/CONTENT elements is tallied.

    Args:
        foldername: directory containing the XML document files.

    Returns:
        dict mapping character tuples (ending in '_') to frequencies.
    """
    vocab = {}
    # Keep CDATA sections intact so their text is still extracted.
    parser = etree.XMLParser(strip_cdata=False)
    paths = [os.path.join(foldername, name) for name in os.listdir(foldername)]
    for i in range(0, len(paths), 4):
        document = paths[i]
        print("Learning document = " + document)
        with open(document, 'r') as handle:
            docs_string = handle.read()
        tree = etree.fromstring("<INIT>\n" + docs_string + "</INIT>", parser)
        for title_tag in tree.findall("DOC/TITLE"):
            # Join the title's text fragments with single spaces.
            title_text = ' '.join(t.strip() for t in title_tag.itertext())
            _tally_words(title_text, vocab)
        for content_tag in tree.findall(".//DOC/CONTENT"):
            content_text = ''.join(content_tag.itertext()).strip()
            _tally_words(content_text, vocab)
    return vocab
## Returns the top numerges merges of pairs learnt from the documents in foldername
def merges(foldername, numerges):
    """Learn the top *numerges* BPE merge pairs from the documents in
    *foldername* and return them in the order they were learnt."""
    vocabulary = vocabularise_documents(foldername)
    print("Reading for vocab construction for learning")
    learnt = learn_tokens(vocabulary, numerges)
    print("Merges learnt!")
    return learnt
# text = "low low low low low lower lower newest newest newest newest newest newest widest widest widest"
# vocab = make_vocab(text)
# print(vocab)
# pr = get_pairs(vocab)
# print(pr)
# best = max(pr, key = pr.get)
# print(best)
# vocab = merge_vocab(best, vocab)
# print(vocab)
# learn(text, 4)