Add in efficient transformer
VikParuchuri committed Jan 29, 2023
1 parent 51dc209 · commit 6e0be7b
Showing 4 changed files with 945 additions and 63 deletions.
data/text_data.py (45 additions, 5 deletions)
@@ -26,6 +26,7 @@ class TextDatasetWrapper:
     subset = None
     x_length = None
     target_length = None
+    add_stop = True
 
     def __init__(self, vocab_size):
         self.sp_vocab_size = vocab_size
@@ -41,13 +42,10 @@ def __init__(self, vocab_size):
         self.pad_token = 0
 
         self.extract_data()
-        tokenizer_data = ""
         for split in self.splits:
             x, target = self.split_x_target(self.data[split])
             self.split_data[split] = {"x": x, "target": target}
-            tokenizer_data += "\n".join(x)
-            tokenizer_data += "\n".join(target)
-        self.train_tokenizer(tokenizer_data)
+        self.train_tokenizer_with_data()
 
         self.model_file = f"{self.name}.model"
         self.sp_model = load_sp_model(self.model_file)
@@ -62,7 +60,7 @@ def __init__(self, vocab_size):
             if not self.target_length:
                 self.target_length = max([len(s) for s in target])
             x = self.pad_sequences(x, self.x_length)
-            target = self.pad_sequences(target, self.y_length, use_stop=True)
+            target = self.pad_sequences(target, self.y_length, use_stop=self.add_stop)
             self.encoded_data[split] = {"x": x, "target": target}
 
         self.create_final_sets()
@@ -74,6 +72,13 @@ def extract_data(self):
             s_data = s_data[:self.split_lengths[i]]
             self.data[split] = s_data
 
+    def train_tokenizer_with_data(self):
+        tokenizer_data = ""
+        for split in self.splits:
+            tokenizer_data += "\n".join(self.split_data[split]["x"])
+            tokenizer_data += "\n".join(self.split_data[split]["target"])
+        self.train_tokenizer(tokenizer_data)
+
     def split_x_target(self, split):
         """Override. Should return x and y"""
         pass
@@ -231,4 +236,39 @@ def trim_length(self, x, target):
                 new_target.append(xr + tr[:(self.target_length - len(xr))])
             else:
                 new_target.append(tr[:self.target_length])
         return new_x, new_target
+
+
+class CNNDatasetDecoderOnly(CNNDatasetWrapper):
+    x_length = 15
+    target_length = 15
+
+    def split_x_target(self, split):
+        chunks = [s.split(" ") for s in split]
+        chunks = [c for c in chunks if len(c) > self.x_length + self.target_length]
+        x = [" ".join(c[:self.x_length]) for c in chunks]
+        target = [" ".join(c[self.x_length:(self.x_length + self.target_length)]) for c in chunks]
+        return x, target
+
+    def train_tokenizer_with_data(self):
+        # Decoder-only: x and target come from the same text, so x alone suffices
+        tokenizer_data = ""
+        for split in self.splits:
+            tokenizer_data += "\n".join(self.split_data[split]["x"])
+        self.train_tokenizer(tokenizer_data)
+
+    def trim_length(self, x, target):
+        new_x = []
+        new_target = []
+        for xv, tr in zip(x, target):
+            if len(xv) < self.x_length or len(tr) < self.target_length:
+                continue
+
+            # Build one contiguous full-length sequence for the decoder, inserting
+            # a start token (id 0) between the prompt and the continuation
+            seq_start = xv[:self.x_length]
+            seq_end = (xv[self.x_length:] + tr)[:self.target_length]
+            tn = seq_start + seq_end
+            xn = seq_start + [0] + seq_end
+            new_x.append(xn)
+            new_target.append(tn)
+        return new_x, new_target
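
For context, a minimal standalone sketch of what the new trim_length override produces, assuming token sequences are plain lists of integer ids and that id 0 doubles as the start token, as the diff suggests; the helper name build_decoder_pair and the sample ids are hypothetical, not part of the commit:

# Illustrative sketch only -- mirrors CNNDatasetDecoderOnly.trim_length above.
# Assumes sequences are lists of integer ids and id 0 is the start token;
# build_decoder_pair and the sample values below are hypothetical.
x_length = 15
target_length = 15

def build_decoder_pair(xv, tr):
    """Return (input, target): the input has a start token spliced in."""
    seq_start = xv[:x_length]                       # prompt token ids
    seq_end = (xv[x_length:] + tr)[:target_length]  # continuation token ids
    tn = seq_start + seq_end                        # target: one contiguous sequence
    xn = seq_start + [0] + seq_end                  # input: start token between halves
    return xn, tn

xv = list(range(1, 21))   # 20 prompt ids
tr = list(range(21, 36))  # 15 continuation ids
xn, tn = build_decoder_pair(xv, tr)
assert len(xn) == x_length + 1 + target_length  # 31
assert xn[x_length] == 0                        # the spliced-in start token

The target is the same token stream without the start token, so the decoder learns to continue the prompt past the inserted marker.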