forked from namisan/mt-dnn
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
3 changed files
with
143 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
import os | ||
from sys import path | ||
path.append(os.getcwd()) | ||
from data_utils.task_def import DataFormat | ||
|
||
def load_conll_ner(file, is_train=True):
    """Read a CoNLL-style file and return one sample per sentence with NER labels.

    Every non-blank line is split on single spaces; the first field is used as
    the token and the last field as its label. A blank line or a line starting
    with ``-DOCSTART`` closes the sentence collected so far.

    Args:
        file: Path of the UTF-8 encoded text file to read.
        is_train: Not used by this function; kept for interface compatibility.

    Returns:
        A list of dicts ``{'uid': int, 'premise': [tokens], 'label': [labels]}``.
        ``uid`` numbers the sentences from 0 in file order.
    """
    rows = []
    sentence = []
    label = []
    with open(file, encoding="utf8") as f:
        for line in f:
            line = line.strip()
            # After strip() a line can never start with "\n", so only the
            # empty-line and document-marker checks are needed.
            if not line or line.startswith('-DOCSTART'):
                if sentence:
                    rows.append({'uid': len(rows), 'premise': sentence, 'label': label})
                    sentence = []
                    label = []
                continue
            splits = line.split(' ')
            sentence.append(splits[0])
            label.append(splits[-1])
    # A file that does not end with a blank line still has a pending sentence;
    # emit it instead of silently dropping it.
    if sentence:
        rows.append({'uid': len(rows), 'premise': sentence, 'label': label})
    return rows
|
||
def load_conll_pos(file, is_train=True):
    """Read a CoNLL-style file and return one sample per sentence with POS labels.

    Every non-blank line is split on single spaces; the first field is used as
    the token and the second field (index 1) as its label. A blank line or a
    line starting with ``-DOCSTART`` closes the sentence collected so far.

    Args:
        file: Path of the UTF-8 encoded text file to read.
        is_train: Not used by this function; kept for interface compatibility.

    Returns:
        A list of dicts ``{'uid': int, 'premise': [tokens], 'label': [labels]}``.
        ``uid`` numbers the sentences from 0 in file order.

    Raises:
        IndexError: If a token line has fewer than two space-separated fields.
    """
    rows = []
    sentence = []
    label = []
    with open(file, encoding="utf8") as f:
        for line in f:
            line = line.strip()
            # After strip() a line can never start with "\n", so only the
            # empty-line and document-marker checks are needed.
            if not line or line.startswith('-DOCSTART'):
                if sentence:
                    rows.append({'uid': len(rows), 'premise': sentence, 'label': label})
                    sentence = []
                    label = []
                continue
            splits = line.split(' ')
            sentence.append(splits[0])
            label.append(splits[1])
    # A file that does not end with a blank line still has a pending sentence;
    # emit it instead of silently dropping it.
    if sentence:
        rows.append({'uid': len(rows), 'premise': sentence, 'label': label})
    return rows
|
||
def load_conll_chunk(file, is_train=True):
    """Read a CoNLL-style file and return one sample per sentence with chunk labels.

    Every non-blank line is split on single spaces; the first field is used as
    the token and the third field (index 2) as its label. A blank line or a
    line starting with ``-DOCSTART`` closes the sentence collected so far.

    Args:
        file: Path of the UTF-8 encoded text file to read.
        is_train: Not used by this function; kept for interface compatibility.

    Returns:
        A list of dicts ``{'uid': int, 'premise': [tokens], 'label': [labels]}``.
        ``uid`` numbers the sentences from 0 in file order.

    Raises:
        IndexError: If a token line has fewer than three space-separated fields.
    """
    rows = []
    sentence = []
    label = []
    with open(file, encoding="utf8") as f:
        for line in f:
            line = line.strip()
            # After strip() a line can never start with "\n", so only the
            # empty-line and document-marker checks are needed.
            if not line or line.startswith('-DOCSTART'):
                if sentence:
                    rows.append({'uid': len(rows), 'premise': sentence, 'label': label})
                    sentence = []
                    label = []
                continue
            splits = line.split(' ')
            sentence.append(splits[0])
            label.append(splits[2])
    # A file that does not end with a blank line still has a pending sentence;
    # emit it instead of silently dropping it.
    if sentence:
        rows.append({'uid': len(rows), 'premise': sentence, 'label': label})
    return rows
|
||
def dump_rows(rows, out_path, data_format=DataFormat.Seqence):
    """Write rows to ``out_path`` as tab-separated ``uid``, ``label``, ``premise`` lines.

    Each row is written with ``%s`` formatting, so list values appear in their
    Python ``str()`` form. An empty ``rows`` produces an empty file.

    Args:
        rows: Iterable of dicts with the keys ``uid``, ``label`` and ``premise``.
        out_path: Path of the output file; it is created or overwritten.
        data_format: Only ``DataFormat.Seqence`` is handled (the spelling is
            the one this file references from ``data_utils.task_def``).

    Raises:
        ValueError: If ``data_format`` is not ``DataFormat.Seqence`` while
            there is at least one row, or if the string form of a field
            contains a tab, which would corrupt the TSV layout.
    """
    with open(out_path, "w", encoding="utf-8") as out_f:
        for row in rows:
            if data_format != DataFormat.Seqence:
                raise ValueError(data_format)
            for col in ("uid", "label", "premise"):
                # Fail loudly instead of dropping into an interactive debugger,
                # which would hang any non-interactive run.
                if "\t" in str(row[col]):
                    raise ValueError(
                        "tab character in column %r of row with uid %r" % (col, row["uid"])
                    )
            out_f.write("%s\t%s\t%s\n" % (row["uid"], row["label"], row["premise"]))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
import os | ||
import argparse | ||
from sys import path | ||
path.append(os.getcwd()) | ||
from data_utils.task_def import DataFormat | ||
from data_utils.log_wrapper import create_logger | ||
from experiments.ner.ner_utils import load_conll_chunk, load_conll_ner, load_conll_pos, dump_rows | ||
|
||
# Module-level logger shared by this script.
# NOTE(review): presumably to_disk/log_file make create_logger also write to
# the named file -- confirm against data_utils.log_wrapper.
logger = create_logger(
    __name__,
    to_disk=True,
    log_file='bert_ner_data_proc_512_cased.log',
)
|
||
def parse_args():
    """Parse this script's command-line options from ``sys.argv``.

    Returns:
        An ``argparse.Namespace`` with ``data_dir`` (required str),
        ``seed`` (int, default 13) and ``output_dir`` (required str).
    """
    cli = argparse.ArgumentParser(description='Preprocessing English NER dataset.')
    # Registered in this order so the generated help text stays the same.
    option_specs = (
        ('--data_dir', {'type': str, 'required': True}),
        ('--seed', {'type': int, 'default': 13}),
        ('--output_dir', {'type': str, 'required': True}),
    )
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)
    return cli.parse_args()
|
||
def main(args):
    """Convert the CoNLL NER splits found in ``args.data_dir`` to TSV files.

    Reads ``train.txt``, ``valid.txt`` and ``test.txt`` from the data
    directory with ``load_conll_ner`` and writes ``ner_train.tsv``,
    ``ner_dev.tsv`` and ``ner_test.tsv`` into ``args.output_dir`` with
    ``dump_rows``.

    Args:
        args: Namespace providing ``data_dir`` and ``output_dir``.

    Raises:
        FileNotFoundError: If the data directory or one of the split files
            does not exist.
    """
    data_dir = os.path.abspath(args.data_dir)
    # The data directory is an input: creating it when missing would only
    # leave an empty directory behind before the first read fails anyway.
    if not os.path.isdir(data_dir):
        raise FileNotFoundError('data_dir is not an existing directory: {}'.format(data_dir))

    train_path = os.path.join(data_dir, 'train.txt')
    dev_path = os.path.join(data_dir, 'valid.txt')
    test_path = os.path.join(data_dir, 'test.txt')

    train_data = load_conll_ner(train_path)
    dev_data = load_conll_ner(dev_path)
    test_data = load_conll_ner(test_path)
    logger.info('Loaded {} NER train samples'.format(len(train_data)))
    logger.info('Loaded {} NER dev samples'.format(len(dev_data)))
    logger.info('Loaded {} NER test samples'.format(len(test_data)))

    bert_root = args.output_dir
    # makedirs also creates missing parent directories, so nested output
    # paths work; exist_ok keeps reruns into the same directory working.
    os.makedirs(bert_root, exist_ok=True)
    train_fout = os.path.join(bert_root, 'ner_train.tsv')
    dev_fout = os.path.join(bert_root, 'ner_dev.tsv')
    test_fout = os.path.join(bert_root, 'ner_test.tsv')

    dump_rows(train_data, train_fout)
    dump_rows(dev_data, dev_fout)
    dump_rows(test_data, test_fout)
    logger.info('done with NER')
|
||
|
||
if __name__ == '__main__':
    # Script entry point: parse the command line and run the conversion.
    main(parse_args())