Skip to content

Commit

Permalink
Repair known BUG and add instructions
Browse files Browse the repository at this point in the history
  • Loading branch information
yuanxiaosc committed May 22, 2019
1 parent 0b52fa7 commit 7cdd06e
Show file tree
Hide file tree
Showing 37 changed files with 83,364 additions and 1 deletion.
369 changes: 369 additions & 0 deletions Usage example 使用方法示例.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,369 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## BERT-for-Sequence-Labeling-and-Text-Classification"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## BERT information"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Take uncased_L-12_H-768_A-12 as an example, which contains the following three files:\n",
"+ uncased_L-12_H-768_A-12/vocab.txt\n",
"+ uncased_L-12_H-768_A-12/bert_config.json\n",
"+ uncased_L-12_H-768_A-12/bert_model.ckpt"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Sequence-Labeling-task 序列标注任务"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"> Examples of model training usage"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### ATIS"
]
},
{
"cell_type": "raw",
"metadata": {},
"source": [
"python run_sequence_labeling.py \\\n",
"--task_name=\"atis\" \\\n",
"--do_train=True \\\n",
"--do_eval=True \\\n",
"--do_predict=True \\\n",
"--data_dir=data/atis_Intent_Detection_and_Slot_Filling \\\n",
"--vocab_file=pretrained_model/uncased_L-12_H-768_A-12/vocab.txt \\\n",
"--bert_config_file=pretrained_model/uncased_L-12_H-768_A-12/bert_config.json \\\n",
"--init_checkpoint=pretrained_model/uncased_L-12_H-768_A-12/bert_model.ckpt \\\n",
"--max_seq_length=128 \\\n",
"--train_batch_size=32 \\\n",
"--learning_rate=2e-5 \\\n",
"--num_train_epochs=3.0 \\\n",
"--output_dir=./output_model/atis_Slot_Filling_epoch3/ "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### SNIPS"
]
},
{
"cell_type": "raw",
"metadata": {},
"source": [
"python run_sequence_labeling.py \\\n",
"--task_name=\"snips\" \\\n",
"--do_train=True \\\n",
"--do_eval=True \\\n",
"--do_predict=True \\\n",
"--data_dir=data/snips_Intent_Detection_and_Slot_Filling \\\n",
"--vocab_file=pretrained_model/uncased_L-12_H-768_A-12/vocab.txt \\\n",
"--bert_config_file=pretrained_model/uncased_L-12_H-768_A-12/bert_config.json \\\n",
"--init_checkpoint=pretrained_model/uncased_L-12_H-768_A-12/bert_model.ckpt \\\n",
"--max_seq_length=128 \\\n",
"--train_batch_size=32 \\\n",
"--learning_rate=2e-5 \\\n",
"--num_train_epochs=3.0 \\\n",
"--output_dir=./output_model/snips_Slot_Filling_epochs3/ "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### CoNLL2003NER"
]
},
{
"cell_type": "raw",
"metadata": {},
"source": [
"python run_sequence_labeling.py \\\n",
"--task_name=\"conll2003ner\" \\\n",
"--do_train=True \\\n",
"--do_eval=True \\\n",
"--do_predict=True \\\n",
"--data_dir=data/CoNLL2003_NER \\\n",
"--vocab_file=pretrained_model/uncased_L-12_H-768_A-12/vocab.txt \\\n",
"--bert_config_file=pretrained_model/uncased_L-12_H-768_A-12/bert_config.json \\\n",
"--init_checkpoint=pretrained_model/uncased_L-12_H-768_A-12/bert_model.ckpt \\\n",
"--max_seq_length=128 \\\n",
"--train_batch_size=32 \\\n",
"--learning_rate=2e-5 \\\n",
"--num_train_epochs=3.0 \\\n",
"--output_dir=./output_model/conll2003ner_epoch3/ "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Sequence labeling task prediction 序列标注任务预测 "
]
},
{
"cell_type": "raw",
"metadata": {},
"source": [
"python run_sequence_labeling.py \\\n",
"--task_name=\"conll2003ner\" \\\n",
"--do_predict=True \\\n",
"--data_dir=data/CoNLL2003_NER \\\n",
"--vocab_file=pretrained_model/uncased_L-12_H-768_A-12/vocab.txt \\\n",
"--bert_config_file=pretrained_model/uncased_L-12_H-768_A-12/bert_config.json \\\n",
"--init_checkpoint=output_model/conll2003ner_epoch3/model.ckpt-653 \\\n",
"--output_dir=./output_predict/conll2003ner_epoch3_ckpt653/ "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Text-Classification Train 文本分类任务训练 "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### ATIS Train"
]
},
{
"cell_type": "raw",
"metadata": {},
"source": [
"python run_text_classification.py \\\n",
" --task_name=atis \\\n",
" --do_train=true \\\n",
" --do_eval=true \\\n",
" --data_dir=data/atis_Intent_Detection_and_Slot_Filling \\\n",
" --vocab_file=pretrained_model/uncased_L-12_H-768_A-12/vocab.txt \\\n",
" --bert_config_file=pretrained_model/uncased_L-12_H-768_A-12/bert_config.json \\\n",
" --init_checkpoint=pretrained_model/uncased_L-12_H-768_A-12/bert_model.ckpt \\\n",
" --max_seq_length=128 \\\n",
" --train_batch_size=32 \\\n",
" --learning_rate=2e-5 \\\n",
" --num_train_epochs=3.0 \\\n",
" --output_dir=./output_model/atis_Intent_Detection_epochs3/"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### ATIS Make Prediction"
]
},
{
"cell_type": "raw",
"metadata": {},
"source": [
"python run_text_classification.py \\\n",
" --task_name=atis \\\n",
" --do_predict=true \\\n",
" --data_dir=data/atis_Intent_Detection_and_Slot_Filling \\\n",
" --vocab_file=pretrained_model/uncased_L-12_H-768_A-12/vocab.txt \\\n",
" --bert_config_file=pretrained_model/uncased_L-12_H-768_A-12/bert_config.json \\\n",
" --init_checkpoint=output_model/atis_Intent_Detection_epochs3/model.ckpt-419 \\\n",
" --max_seq_length=128 \\\n",
" --output_dir=./output_predict/atis_Intent_Detection_epoch3_ckpt419"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### SNIPS Make Prediction"
]
},
{
"cell_type": "raw",
"metadata": {},
"source": [
"python run_text_classification.py \\\n",
" --task_name=Snips \\\n",
" --do_predict=true \\\n",
" --data_dir=data/snips_Intent_Detection_and_Slot_Filling \\\n",
" --vocab_file=pretrained_model/uncased_L-12_H-768_A-12/vocab.txt \\\n",
" --bert_config_file=pretrained_model/uncased_L-12_H-768_A-12/bert_config.json \\\n",
" --init_checkpoint=output_model/snips_Intent_Detection_epochs3/model.ckpt-1226 \\\n",
" --max_seq_length=128 \\\n",
" --output_dir=./output_predict/snips_Intent_Detection_epoch3_ckpt1226/"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Joint task training 联合任务训练"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### SNIPS Train"
]
},
{
"cell_type": "raw",
"metadata": {},
"source": [
"python run_sequence_labeling_and_text_classification.py \\\n",
" --task_name=snips \\\n",
" --do_train=true \\\n",
" --do_eval=true \\\n",
" --data_dir=data/snips_Intent_Detection_and_Slot_Filling \\\n",
" --vocab_file=pretrained_model/uncased_L-12_H-768_A-12/vocab.txt \\\n",
" --bert_config_file=pretrained_model/uncased_L-12_H-768_A-12/bert_config.json \\\n",
" --init_checkpoint=pretrained_model/uncased_L-12_H-768_A-12/bert_model.ckpt \\\n",
" --num_train_epochs=3.0 \\\n",
" --output_dir=./output_model/snips_join_task_epoch3/"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### ATIS Train"
]
},
{
"cell_type": "raw",
"metadata": {},
"source": [
"python run_sequence_labeling_and_text_classification.py \\\n",
" --task_name=Atis \\\n",
" --do_train=true \\\n",
" --do_eval=true \\\n",
" --data_dir=data/atis_Intent_Detection_and_Slot_Filling \\\n",
" --vocab_file=pretrained_model/uncased_L-12_H-768_A-12/vocab.txt \\\n",
" --bert_config_file=pretrained_model/uncased_L-12_H-768_A-12/bert_config.json \\\n",
" --init_checkpoint=pretrained_model/uncased_L-12_H-768_A-12/bert_model.ckpt \\\n",
" --num_train_epochs=3.0 \\\n",
" --output_dir=./output_model/atis_join_task_epoch3/"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### ATIS Next Train (continue training from a saved checkpoint)"
]
},
{
"cell_type": "raw",
"metadata": {},
"source": [
"python run_sequence_labeling_and_text_classification.py \\\n",
" --task_name=Atis \\\n",
" --do_train=true \\\n",
" --do_eval=true \\\n",
" --data_dir=data/atis_Intent_Detection_and_Slot_Filling \\\n",
" --vocab_file=pretrained_model/uncased_L-12_H-768_A-12/vocab.txt \\\n",
" --bert_config_file=pretrained_model/uncased_L-12_H-768_A-12/bert_config.json \\\n",
" --init_checkpoint=output_model/atis_join_task_epoch3/model.ckpt-1399 \\\n",
" --num_train_epochs=3.0 \\\n",
" --output_dir=./output_model/atis_join_task_epoch6/"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Joint task prediction 联合任务预测 "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### SNIPS Make Prediction"
]
},
{
"cell_type": "raw",
"metadata": {},
"source": [
"python run_sequence_labeling_and_text_classification.py \\\n",
" --task_name=Snips \\\n",
" --do_predict=true \\\n",
" --data_dir=data/snips_Intent_Detection_and_Slot_Filling \\\n",
" --vocab_file=pretrained_model/uncased_L-12_H-768_A-12/vocab.txt \\\n",
" --bert_config_file=pretrained_model/uncased_L-12_H-768_A-12/bert_config.json \\\n",
" --init_checkpoint=output_model/snips_join_task_epoch3/model.ckpt-1000 \\\n",
" --max_seq_length=128 \\\n",
" --output_dir=./output_predict/snips_join_task_epoch3_ckpt1000"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### ATIS Make Prediction"
]
},
{
"cell_type": "raw",
"metadata": {},
"source": [
"python run_sequence_labeling_and_text_classification.py \\\n",
" --task_name=Atis \\\n",
" --do_predict=true \\\n",
" --data_dir=data/atis_Intent_Detection_and_Slot_Filling \\\n",
" --vocab_file=pretrained_model/uncased_L-12_H-768_A-12/vocab.txt \\\n",
" --bert_config_file=pretrained_model/uncased_L-12_H-768_A-12/bert_config.json \\\n",
" --init_checkpoint=output_model/atis_join_task_epoch3/model.ckpt-1000 \\\n",
" --max_seq_length=128 \\\n",
" --output_dir=./output_predict/atis_join_task_epoch3_ckpt1000"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.7"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
26 changes: 26 additions & 0 deletions calculating_model_score/calculate_atis_intent.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import os
import numpy as np
from sklearn_metrics_function import show_metrics

# Score ATIS intent-detection predictions against the gold labels.
#
# Inputs (paths are relative to the current working directory):
#   ATIS_Intent/label            - gold intent labels, one per line
#   ATIS_Intent/test_results.tsv - per-class scores written by the classifier,
#                                  one row per example
#
# Intent label vocabulary. The list index of each label is used below as the
# class id when decoding the score matrix, so the order matters.
# NOTE(review): the order is assumed to match the column order of
# test_results.tsv - confirm against the processor that produced the file.
ATIS_intent_label = ['atis_abbreviation', 'atis_aircraft', 'atis_aircraft#atis_flight#atis_flight_no',
                     'atis_airfare', 'atis_airfare#atis_flight', 'atis_airfare#atis_flight_time',
                     'atis_airline', 'atis_airline#atis_flight_no', 'atis_airport', 'atis_capacity',
                     'atis_cheapest', 'atis_city', 'atis_day_name', 'atis_distance', 'atis_flight',
                     'atis_flight#atis_airfare', 'atis_flight#atis_airline', 'atis_flight_no',
                     'atis_flight_no#atis_airline', 'atis_flight_time', 'atis_ground_fare',
                     'atis_ground_service', 'atis_ground_service#atis_ground_fare', 'atis_meal',
                     'atis_quantity', 'atis_restriction']

# Gold labels: strip the line terminator from every line of the label file.
with open(os.path.join("ATIS_Intent", "label")) as label_f:
    label_list = [label.replace("\n", "") for label in label_f]

# Predicted scores are read as one flat 1-D float array, then regrouped into
# rows with one column per intent label. reshape() raises ValueError if the
# number of values is not a multiple of len(ATIS_intent_label).
predit_label_value = np.fromfile(os.path.join("ATIS_Intent", "test_results.tsv"), sep="\t")
predit_label_value = predit_label_value.reshape(-1, len(ATIS_intent_label))
# Highest-scoring column per row -> predicted class index -> label string.
predit_label_value = np.argmax(predit_label_value, axis=1)
predit_label = [ATIS_intent_label[label_index] for label_index in predit_label_value]

# Metrics are only meaningful when gold and predicted labels line up one-to-one;
# fail early with a clear message instead of scoring misaligned data.
if len(label_list) != len(predit_label):
    raise ValueError(
        "Number of gold labels (%d) does not match number of predictions (%d)"
        % (len(label_list), len(predit_label)))

show_metrics(y_test=label_list, y_predict=predit_label, labels=ATIS_intent_label)
Loading

0 comments on commit 7cdd06e

Please sign in to comment.