Commit

ner data def
namisan committed Oct 12, 2019
1 parent 3154388 commit 8b66f4e
Showing 3 changed files with 143 additions and 4 deletions.
11 changes: 7 additions & 4 deletions data_utils/task_def.py
@@ -2,16 +2,19 @@
 
 from enum import IntEnum
 class TaskType(IntEnum):
-    Classification = 1
-    Regression = 2
-    Ranking = 3
-    Span = 4
+    Classification = 1
+    Regression = 2
+    Ranking = 3
+    Span = 4
+    SeqenceLabeling = 5
 
 
 
 class DataFormat(IntEnum):
     PremiseOnly = 1
     PremiseAndOneHypothesis = 2
     PremiseAndMultiHypothesis = 3
+    Seqence = 4
 
 class EncoderModelType(IntEnum):
     BERT = 1
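For context, a minimal sketch of how the new enum members might be checked downstream; the helper below is hypothetical and not part of this commit:

from data_utils.task_def import TaskType, DataFormat

def is_sequence_labeling(task_type, data_format):
    # Hypothetical helper: a sequence-labeling task pairs the new
    # TaskType.SeqenceLabeling value with the new DataFormat.Seqence layout.
    return task_type == TaskType.SeqenceLabeling and data_format == DataFormat.Seqence

assert is_sequence_labeling(TaskType.SeqenceLabeling, DataFormat.Seqence)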
85 changes: 85 additions & 0 deletions experiments/ner/ner_utils.py
@@ -0,0 +1,85 @@
import os
from sys import path
path.append(os.getcwd())
from data_utils.task_def import DataFormat

def load_conll_ner(file, is_train=True):
    """Load a CoNLL-2003 style file: one token per line, blank lines separate sentences."""
    rows = []
    cnt = 0
    sentence = []
    label = []
    with open(file, encoding="utf8") as f:
        for line in f:
            line = line.strip()
            if len(line) == 0 or line.startswith('-DOCSTART') or line[0] == "\n":
                if len(sentence) > 0:
                    sample = {'uid': cnt, 'premise': sentence, 'label': label}
                    rows.append(sample)
                    sentence = []
                    label = []
                    cnt += 1
                continue
            splits = line.split(' ')
            sentence.append(splits[0])
            # NER tag is the last column.
            label.append(splits[-1])
    if len(sentence) > 0:
        sample = {'uid': cnt, 'premise': sentence, 'label': label}
        rows.append(sample)
    return rows
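As an illustration of the expected input, a small self-contained check of load_conll_ner on a CoNLL-2003 style snippet (the tempfile plumbing is only for the example):

import tempfile

conll_sample = (
    "-DOCSTART- -X- -X- O\n"
    "\n"
    "EU NNP B-NP B-ORG\n"
    "rejects VBZ B-VP O\n"
    "German JJ B-NP B-MISC\n"
    "\n"
)
with tempfile.NamedTemporaryFile("w", suffix=".txt", delete=False) as tmp:
    tmp.write(conll_sample)

rows = load_conll_ner(tmp.name)
# One row per sentence: tokens in 'premise', last-column NER tags in 'label'.
print(rows[0]['premise'])  # ['EU', 'rejects', 'German']
print(rows[0]['label'])    # ['B-ORG', 'O', 'B-MISC']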

def load_conll_pos(file, is_train=True):
    """Load a CoNLL-style file for POS tagging: one token per line, blank lines separate sentences."""
    rows = []
    cnt = 0
    sentence = []
    label = []
    with open(file, encoding="utf8") as f:
        for line in f:
            line = line.strip()
            if len(line) == 0 or line.startswith('-DOCSTART') or line[0] == "\n":
                if len(sentence) > 0:
                    sample = {'uid': cnt, 'premise': sentence, 'label': label}
                    rows.append(sample)
                    sentence = []
                    label = []
                    cnt += 1
                continue
            splits = line.split(' ')
            sentence.append(splits[0])
            # POS tag is the second column.
            label.append(splits[1])
    if len(sentence) > 0:
        sample = {'uid': cnt, 'premise': sentence, 'label': label}
        rows.append(sample)
    return rows

def load_conll_chunk(file, is_train=True):
    """Load a CoNLL-style file for chunking: one token per line, blank lines separate sentences."""
    rows = []
    cnt = 0
    sentence = []
    label = []
    with open(file, encoding="utf8") as f:
        for line in f:
            line = line.strip()
            if len(line) == 0 or line.startswith('-DOCSTART') or line[0] == "\n":
                if len(sentence) > 0:
                    sample = {'uid': cnt, 'premise': sentence, 'label': label}
                    rows.append(sample)
                    sentence = []
                    label = []
                    cnt += 1
                continue
            splits = line.split(' ')
            sentence.append(splits[0])
            # Chunk tag is the third column.
            label.append(splits[2])
    if len(sentence) > 0:
        sample = {'uid': cnt, 'premise': sentence, 'label': label}
        rows.append(sample)
    return rows

def dump_rows(rows, out_path, data_format=DataFormat.Seqence):
    """Write rows as uid<TAB>label<TAB>premise; list-valued fields are written as their Python repr."""
    with open(out_path, "w", encoding="utf-8") as out_f:
        for row in rows:
            if data_format == DataFormat.Seqence:
                for col in ["uid", "label", "premise"]:
                    if "\t" in str(row[col]):
                        # A stray tab would break the three-column TSV layout.
                        raise ValueError("tab found in column %r of row %s" % (col, row["uid"]))
                out_f.write("%s\t%s\t%s\n" % (row["uid"], row["label"], row["premise"]))
            else:
                raise ValueError(data_format)
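To show the on-disk layout this produces, a short sketch (the row literal and output path are made up; the format follows dump_rows above):

rows = [{'uid': 0,
         'premise': ['EU', 'rejects', 'German'],
         'label': ['B-ORG', 'O', 'B-MISC']}]
dump_rows(rows, 'ner_sample.tsv')
# ner_sample.tsv now holds one line per row, uid<TAB>label<TAB>premise,
# with the list-valued fields serialized as their Python repr:
# 0	['B-ORG', 'O', 'B-MISC']	['EU', 'rejects', 'German']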
51 changes: 51 additions & 0 deletions experiments/ner/prepro.py
@@ -0,0 +1,51 @@
import os
import argparse
from sys import path
path.append(os.getcwd())
from data_utils.task_def import DataFormat
from data_utils.log_wrapper import create_logger
from experiments.ner.ner_utils import load_conll_chunk, load_conll_ner, load_conll_pos, dump_rows

logger = create_logger(__name__, to_disk=True, log_file='bert_ner_data_proc_512_cased.log')

def parse_args():
    parser = argparse.ArgumentParser(description='Preprocessing English NER dataset.')
    parser.add_argument('--data_dir', type=str, required=True)
    parser.add_argument('--seed', type=int, default=13)
    parser.add_argument('--output_dir', type=str, required=True)
    args = parser.parse_args()
    return args

def main(args):
    data_dir = args.data_dir
    data_dir = os.path.abspath(data_dir)
    if not os.path.isdir(data_dir):
        os.mkdir(data_dir)

    train_path = os.path.join(data_dir, 'train.txt')
    dev_path = os.path.join(data_dir, 'valid.txt')
    test_path = os.path.join(data_dir, 'test.txt')

    train_data = load_conll_ner(train_path)
    dev_data = load_conll_ner(dev_path)
    test_data = load_conll_ner(test_path)
    logger.info('Loaded {} NER train samples'.format(len(train_data)))
    logger.info('Loaded {} NER dev samples'.format(len(dev_data)))
    logger.info('Loaded {} NER test samples'.format(len(test_data)))

    bert_root = args.output_dir
    if not os.path.isdir(bert_root):
        os.mkdir(bert_root)
    train_fout = os.path.join(bert_root, 'ner_train.tsv')
    dev_fout = os.path.join(bert_root, 'ner_dev.tsv')
    test_fout = os.path.join(bert_root, 'ner_test.tsv')

    dump_rows(train_data, train_fout)
    dump_rows(dev_data, dev_fout)
    dump_rows(test_data, test_fout)
    logger.info('done with NER')


if __name__ == '__main__':
    args = parse_args()
    main(args)
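Assuming the CoNLL files are named train.txt, valid.txt, and test.txt under --data_dir, as the code above expects, a typical invocation might look like the following (directory names are placeholders):

python experiments/ner/prepro.py --data_dir data/ner --output_dir data/canonical_data

This writes ner_train.tsv, ner_dev.tsv, and ner_test.tsv into the --output_dir folder.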
