Skip to content

Commit

Permalink
Made the code base more maintainable
Browse files Browse the repository at this point in the history
  • Loading branch information
lhr0909 committed Feb 15, 2022
1 parent 9a7bdbb commit ba85450
Show file tree
Hide file tree
Showing 4 changed files with 245 additions and 431 deletions.
4 changes: 2 additions & 2 deletions package.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"name": "gpt3-tokenizer",
"version": "1.1.1",
"version": "1.1.2",
"license": "MIT",
"author": "Simon Liang <simon@x-tech.io>",
"repository": {
Expand All @@ -20,7 +20,7 @@
},
"scripts": {
"start": "tsdx watch",
"build": "npm run build:browser && tsdx build",
"build": "npm run build:browser && tsdx build --target node --format cjs",
"build:browser": "rimraf dist-browser && tsdx build --target browser --format esm",
"test": "tsdx test",
"lint": "tsdx lint",
Expand Down
223 changes: 8 additions & 215 deletions src/index-browser.ts
Original file line number Diff line number Diff line change
@@ -1,230 +1,23 @@
import ArrayKeyedMap from 'array-keyed-map';

import TextEncoder from './text-encoder';
import TextDecoder from './text-decoder';
import GPT3Tokenizer from './tokenizer';

import bpeVocab from './bpe-vocab';
import bpeRegex from './bpe-regex';
import encodings from './encodings';

// Python-style range helper: integers from x (inclusive) to y (exclusive).
const range = (x: number, y: number): number[] => {
  const upTo = [...Array(y).keys()];
  return upTo.slice(x);
};

// Python-style ord(): UTF-16 code unit of the first character of x.
const ord = (x: string): number => x.charCodeAt(0);

// Python-style chr(): single-character string for a UTF-16 code unit n.
const chr = (n: number): string => String.fromCharCode(n);

// NOTE(review): diff-view artifact — this span interleaves the OLD class
// (GPT3Tokenizer, whose core state moved to ./tokenizer in this commit) with
// the NEW browser subclass (GPT3BrowserTokenizer). Two class headers and two
// constructor bodies appear merged here, so the span does not parse as-is.
export default class GPT3Tokenizer {
// Raw newline-separated BPE merge vocab text (old implementation state).
private vocab: string;
private nMergedSpaces: number;
private nVocab: number;

// token string -> token id, and its inverse (built in initialize()).
private encodings: { [key: string]: number };
private decodings: { [key: number]: string };

// NEW: browser subclass that supplies the platform TextEncoder/TextDecoder.
export default class GPT3BrowserTokenizer extends GPT3Tokenizer {
private textEncoder: TextEncoder;
private textDecoder: TextDecoder;
// byte value <-> printable stand-in character (see bytesToUnicode()).
private byteEncoder: Map<number, string>;
private byteDecoder: Map<string, number>;

// merge pair -> rank; lower rank merges first in bpe().
private bpeRanks: ArrayKeyedMap<[string, string], number>;
// memoized bpe() results, keyed by the byte-encoded token.
private cache: { [key: string]: string };

// 'codex' enables the extra merged-space tokens (nMergedSpaces below).
constructor(options: { type: 'gpt3' | 'codex' }) {
this.encodings = encodings;
this.vocab = bpeVocab;
// NEW constructor: delegate the shared setup above to the base class.
super(options);

this.textEncoder = new TextEncoder();
this.textDecoder = new TextDecoder();
// codex reserves 24 extra token ids for runs of 2..25 spaces.
this.nMergedSpaces = options.type === 'codex' ? 24 : 0;
this.nVocab = 50257 + this.nMergedSpaces;
this.decodings = {};
this.bpeRanks = new ArrayKeyedMap<[string, string], number>();
this.byteEncoder = new Map();
this.byteDecoder = new Map();
this.cache = {};
this.initialize();
}

// One-time setup: parse the BPE vocab, optionally extend it with the codex
// merged-space tokens, and build every encode/decode lookup table.
initialize() {
// Guard against a failed or truncated vocab asset load.
if (this.vocab.length < 100) {
throw new Error('Tokenizer vocab file did not load correctly');
}
const vocabLines = this.vocab.split('\n');
// Drop the first (header) and last (empty) lines; each remaining line is a
// whitespace-separated merge pair.
const bpeMerges: [string, string][] = vocabLines
.slice(1, vocabLines.length - 1)
.map((line: string) =>
line.split(/(\s+)/).filter((part: string) => part.trim().length > 0)
) as [string, string][];

// add merged spaces for codex tokenizer
// ('\u0120' is the byte-encoder stand-in character used for a space byte)
if (this.nMergedSpaces > 0) {
for (let i = 1; i < this.nMergedSpaces; i++) {
for (let j = 1; j < this.nMergedSpaces; j++) {
if (i + j <= this.nMergedSpaces) {
bpeMerges.push(['\u0120'.repeat(i), '\u0120'.repeat(j)]);
}
}
}

// The space-run tokens occupy the last nMergedSpaces vocab slots.
for (let i = 0; i < this.nMergedSpaces; i++) {
this.encodings['\u0120'.repeat(i + 2)] =
this.nVocab - this.nMergedSpaces + i;
}
}

// Invert encodings so token ids can be mapped back to their strings.
for (const key of Object.keys(this.encodings)) {
this.decodings[this.encodings[key]] = key;
}

this.byteEncoder = this.bytesToUnicode();

// Reverse byte mapping, used by decode().
this.byteEncoder.forEach((value, key) => {
this.byteDecoder.set(value, key);
});

// Rank each merge pair by its position in the vocab (earlier = merged first).
this.zip(this.bpeRanks, bpeMerges, range(0, bpeMerges.length));
}

zip<X, Y>(result: Map<X, Y>, x: X[], y: Y[]): Map<X, Y> {
x.forEach((_, idx) => {
result.set(x[idx], y[idx]);
});

return result;
}

bytesToUnicode(): Map<number, string> {
const bs = range(ord('!'), ord('~') + 1).concat(
range(ord('\xa1'), ord('\xac') + 1),
range(ord('\xae'), ord('\xff') + 1)
);

let cs: any[] = bs.slice();
let n = 0;

for (let b = 0; b < Math.pow(2, 8); b++) {
if (!bs.includes(b)) {
bs.push(b);
cs.push(Math.pow(2, 8) + n);
n = n + 1;
}
}

cs = cs.map((c: number) => chr(c));

const result = new Map<number, string>();
this.zip(result, bs, cs as string[]);
return result;
}

// Collect the set of adjacent symbol pairs appearing in `word`.
getPairs(word: string[]): Set<[string, string]> {
  const result = new Set<[string, string]>();

  for (let i = 1; i < word.length; i++) {
    result.add([word[i - 1], word[i]]);
  }

  return result;
}

// Apply the learned BPE merges to one byte-encoded token and return its
// sub-tokens joined by single spaces. Results are memoized in this.cache.
bpe(token: string) {
if (token in this.cache) {
return this.cache[token];
}

// Start from individual characters and repeatedly merge the best-ranked pair.
let word: string[] | string = token.split('');

let pairs = this.getPairs(word);

if (!pairs || pairs.size === 0) {
return token;
}

while (true) {
// Find the adjacent pair with the lowest merge rank; pairs not in
// bpeRanks get the sentinel 1e11 so they are never preferred.
const minPairs: { [key: number]: [string, string] } = {};
for (const pair of Array.from(pairs)) {
const rank = this.bpeRanks.get(pair);
minPairs[(isNaN(rank as number) ? 1e11 : rank as number)] = pair;
}

const bigram = minPairs[Math.min(...Object.keys(minPairs).map(x => parseInt(x)))];

// Best candidate is unknown -> no mergeable pair remains.
if (!this.bpeRanks.has(bigram)) {
break;
}

const first = bigram[0];
const second = bigram[1];
let newWord: string[] = [];
let i = 0;

// Rebuild the word, fusing every adjacent (first, second) occurrence.
while (i < word.length) {
const j = word.indexOf(first, i);
if (j === -1) {
newWord = newWord.concat(word.slice(i));
break;
}
newWord = newWord.concat(word.slice(i, j));
i = j;

// Only merge when `second` immediately follows; a trailing lone `first`
// is copied through unchanged.
if (word[i] === first && i < word.length - 1 && word[i + 1] === second) {
newWord.push(first + second);
i = i + 2;
} else {
newWord.push(word[i]);
i = i + 1;
}
}

word = newWord;
if (word.length === 1) {
break;
} else {
pairs = this.getPairs(word);
}
}

// Cache and return the space-joined sub-token string.
word = word.join(' ');
this.cache[token] = word;

return word;
}

// NOTE(review): diff-view artifact — the old encode()'s closing brace was
// replaced in this commit by the new encodeUtf8() method, so the two
// fragments below appear merged and do not parse as-is.
encode(text: string): { bpe: number[]; text: string[] } {
let bpeTokens: number[] = [];
let texts: string[] = [];
// bpeRegex splits the input into GPT-style word/punctuation chunks.
const matches = text.match(bpeRegex) || [];

for (let token of matches) {
// Map each UTF-8 byte to its printable stand-in before running BPE.
token = Array.from(this.textEncoder.encode(token)).map((x) => this.byteEncoder.get(x)).join('');
const newTokens = this.bpe(token).split(' ').map((x) => this.encodings[x]);
bpeTokens = bpeTokens.concat(newTokens);
// Keep the decoded text of each id alongside the ids themselves.
texts = texts.concat(
newTokens.map((x) => this.decode([x])),
);
}

return {
bpe: bpeTokens,
text: texts,
};
// NEW browser-specific helper: UTF-8 encode via the platform TextEncoder.
encodeUtf8(text: string): Uint8Array {
return this.textEncoder.encode(text);
}

// NOTE(review): diff-view artifact — the old decode()'s closing brace was
// replaced in this commit by the new decodeUtf8() method, so the fragments
// below appear merged and do not parse as-is.
decode(tokens: number[]): string {
// Token ids -> stand-in characters -> original bytes -> UTF-8 string.
const text = tokens.map((x) => this.decodings[x]).join('');
return this.textDecoder.decode(
new Uint8Array(
text.split('').map((x) => this.byteDecoder.get(x) as number),
),
);
// NEW browser-specific helper: UTF-8 decode via the platform TextDecoder.
decodeUtf8(bytes: Uint8Array): string {
return this.textDecoder.decode(bytes);
}
}
Loading

0 comments on commit ba85450

Please sign in to comment.