-
Notifications
You must be signed in to change notification settings - Fork 97
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Repair known BUG and add instructions
- Loading branch information
1 parent
0b52fa7
commit 7cdd06e
Showing
37 changed files
with
83,364 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,369 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## BERT-for-Sequence-Labeling-and-Text-Classification" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## BERT information" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"Take uncased_L-12_H-768_A-12 as an example, which contains the following three files:\n", | ||
"+ uncased_L-12_H-768_A-12/vocab.txt\n", | ||
"+ uncased_L-12_H-768_A-12/bert_config.json\n", | ||
"+ uncased_L-12_H-768_A-12/bert_model.ckpt" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## Sequence-Labeling-task 序列标注任务" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"> Examples of model training usage" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"### ATIS" | ||
] | ||
}, | ||
{ | ||
"cell_type": "raw", | ||
"metadata": {}, | ||
"source": [ | ||
"python run_sequence_labeling.py \\\n", | ||
"--task_name=\"atis\" \\\n", | ||
"--do_train=True \\\n", | ||
"--do_eval=True \\\n", | ||
"--do_predict=True \\\n", | ||
"--data_dir=data/atis_Intent_Detection_and_Slot_Filling \\\n", | ||
"--vocab_file=pretrained_model/uncased_L-12_H-768_A-12/vocab.txt \\\n", | ||
"--bert_config_file=pretrained_model/uncased_L-12_H-768_A-12/bert_config.json \\\n", | ||
"--init_checkpoint=pretrained_model/uncased_L-12_H-768_A-12/bert_model.ckpt \\\n", | ||
"--max_seq_length=128 \\\n", | ||
"--train_batch_size=32 \\\n", | ||
"--learning_rate=2e-5 \\\n", | ||
"--num_train_epochs=3.0 \\\n", | ||
"--output_dir=./output_model/atis_Slot_Filling_epoch3/ " | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"### SNIPS" | ||
] | ||
}, | ||
{ | ||
"cell_type": "raw", | ||
"metadata": {}, | ||
"source": [ | ||
"python run_sequence_labeling.py \\\n", | ||
"--task_name=\"snips\" \\\n", | ||
"--do_train=True \\\n", | ||
"--do_eval=True \\\n", | ||
"--do_predict=True \\\n", | ||
"--data_dir=data/snips_Intent_Detection_and_Slot_Filling \\\n", | ||
"--vocab_file=pretrained_model/uncased_L-12_H-768_A-12/vocab.txt \\\n", | ||
"--bert_config_file=pretrained_model/uncased_L-12_H-768_A-12/bert_config.json \\\n", | ||
"--init_checkpoint=pretrained_model/uncased_L-12_H-768_A-12/bert_model.ckpt \\\n", | ||
"--max_seq_length=128 \\\n", | ||
"--train_batch_size=32 \\\n", | ||
"--learning_rate=2e-5 \\\n", | ||
"--num_train_epochs=3.0 \\\n", | ||
"--output_dir=./output_model/snips_Slot_Filling_epochs3/ " | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"### CoNLL2003NER" | ||
] | ||
}, | ||
{ | ||
"cell_type": "raw", | ||
"metadata": {}, | ||
"source": [ | ||
"python run_sequence_labeling.py \\\n", | ||
"--task_name=\"conll2003ner\" \\\n", | ||
"--do_train=True \\\n", | ||
"--do_eval=True \\\n", | ||
"--do_predict=True \\\n", | ||
"--data_dir=data/CoNLL2003_NER \\\n", | ||
"--vocab_file=pretrained_model/uncased_L-12_H-768_A-12/vocab.txt \\\n", | ||
"--bert_config_file=pretrained_model/uncased_L-12_H-768_A-12/bert_config.json \\\n", | ||
"--init_checkpoint=pretrained_model/uncased_L-12_H-768_A-12/bert_model.ckpt \\\n", | ||
"--max_seq_length=128 \\\n", | ||
"--train_batch_size=32 \\\n", | ||
"--learning_rate=2e-5 \\\n", | ||
"--num_train_epochs=3.0 \\\n", | ||
"--output_dir=./output_model/conll2003ner_epoch3/ " | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## Sequence labeling task prediction 序列标注任务预测 " | ||
] | ||
}, | ||
{ | ||
"cell_type": "raw", | ||
"metadata": {}, | ||
"source": [ | ||
"python run_sequence_labeling.py \\\n", | ||
"--task_name=\"conll2003ner\" \\\n", | ||
"--do_predict=True \\\n", | ||
"--data_dir=data/CoNLL2003_NER \\\n", | ||
"--vocab_file=pretrained_model/uncased_L-12_H-768_A-12/vocab.txt \\\n", | ||
"--bert_config_file=pretrained_model/uncased_L-12_H-768_A-12/bert_config.json \\\n", | ||
"--init_checkpoint=output_model/conll2003ner_epoch3/model.ckpt-653 \\\n", | ||
"--output_dir=./output_predict/conll2003ner_epoch3_ckpt653/ " | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## Text-Classification Train 文本分类任务训练 " | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"### ATIS Train" | ||
] | ||
}, | ||
{ | ||
"cell_type": "raw", | ||
"metadata": {}, | ||
"source": [ | ||
"python run_text_classification.py \\\n", | ||
" --task_name=atis \\\n", | ||
" --do_train=true \\\n", | ||
" --do_eval=true \\\n", | ||
" --data_dir=data/atis_Intent_Detection_and_Slot_Filling \\\n", | ||
" --vocab_file=pretrained_model/uncased_L-12_H-768_A-12/vocab.txt \\\n", | ||
" --bert_config_file=pretrained_model/uncased_L-12_H-768_A-12/bert_config.json \\\n", | ||
" --init_checkpoint=pretrained_model/uncased_L-12_H-768_A-12/bert_model.ckpt \\\n", | ||
" --max_seq_length=128 \\\n", | ||
" --train_batch_size=32 \\\n", | ||
" --learning_rate=2e-5 \\\n", | ||
" --num_train_epochs=3.0 \\\n", | ||
" --output_dir=./output_model/atis_Intent_Detection_epochs3/" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"### ATIS Make Prediction" | ||
] | ||
}, | ||
{ | ||
"cell_type": "raw", | ||
"metadata": {}, | ||
"source": [ | ||
"python run_text_classification.py \\\n", | ||
" --task_name=atis \\\n", | ||
" --do_predict=true \\\n", | ||
" --data_dir=data/atis_Intent_Detection_and_Slot_Filling \\\n", | ||
" --vocab_file=pretrained_model/uncased_L-12_H-768_A-12/vocab.txt \\\n", | ||
" --bert_config_file=pretrained_model/uncased_L-12_H-768_A-12/bert_config.json \\\n", | ||
" --init_checkpoint=output_model/atis_Intent_Detection_epochs3/model.ckpt-419 \\\n", | ||
" --max_seq_length=128 \\\n", | ||
" --output_dir=./output_predict/atis_Intent_Detection_epoch3_ckpt419" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"### SNIPS Make Prediction" | ||
] | ||
}, | ||
{ | ||
"cell_type": "raw", | ||
"metadata": {}, | ||
"source": [ | ||
"python run_text_classification.py \\\n", | ||
" --task_name=Snips \\\n", | ||
" --do_predict=true \\\n", | ||
" --data_dir=data/snips_Intent_Detection_and_Slot_Filling \\\n", | ||
" --vocab_file=pretrained_model/uncased_L-12_H-768_A-12/vocab.txt \\\n", | ||
" --bert_config_file=pretrained_model/uncased_L-12_H-768_A-12/bert_config.json \\\n", | ||
" --init_checkpoint=output_model/snips_Intent_Detection_epochs3/model.ckpt-1226 \\\n", | ||
" --max_seq_length=128 \\\n", | ||
" --output_dir=./output_predict/snips_Intent_Detection_epoch3_ckpt1226/" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## Joint task training 联合任务训练" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"### SNIPS Train" | ||
] | ||
}, | ||
{ | ||
"cell_type": "raw", | ||
"metadata": {}, | ||
"source": [ | ||
"python run_sequence_labeling_and_text_classification.py \\\n", | ||
" --task_name=snips \\\n", | ||
" --do_train=true \\\n", | ||
" --do_eval=true \\\n", | ||
" --data_dir=data/snips_Intent_Detection_and_Slot_Filling \\\n", | ||
" --vocab_file=pretrained_model/uncased_L-12_H-768_A-12/vocab.txt \\\n", | ||
" --bert_config_file=pretrained_model/uncased_L-12_H-768_A-12/bert_config.json \\\n", | ||
" --init_checkpoint=pretrained_model/uncased_L-12_H-768_A-12/bert_model.ckpt \\\n", | ||
" --num_train_epochs=3.0 \\\n", | ||
" --output_dir=./output_model/snips_join_task_epoch3/" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"### ATIS Train" | ||
] | ||
}, | ||
{ | ||
"cell_type": "raw", | ||
"metadata": {}, | ||
"source": [ | ||
"python run_sequence_labeling_and_text_classification.py \\\n", | ||
" --task_name=Atis \\\n", | ||
" --do_train=true \\\n", | ||
" --do_eval=true \\\n", | ||
" --data_dir=data/atis_Intent_Detection_and_Slot_Filling \\\n", | ||
" --vocab_file=pretrained_model/uncased_L-12_H-768_A-12/vocab.txt \\\n", | ||
" --bert_config_file=pretrained_model/uncased_L-12_H-768_A-12/bert_config.json \\\n", | ||
" --init_checkpoint=pretrained_model/uncased_L-12_H-768_A-12/bert_model.ckpt \\\n", | ||
" --num_train_epochs=3.0 \\\n", | ||
" --output_dir=./output_model/atis_join_task_epoch3/" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"### ATIS Continue Training (resume from a saved checkpoint)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "raw", | ||
"metadata": {}, | ||
"source": [ | ||
"python run_sequence_labeling_and_text_classification.py \\\n", | ||
" --task_name=Atis \\\n", | ||
" --do_train=true \\\n", | ||
" --do_eval=true \\\n", | ||
" --data_dir=data/atis_Intent_Detection_and_Slot_Filling \\\n", | ||
" --vocab_file=pretrained_model/uncased_L-12_H-768_A-12/vocab.txt \\\n", | ||
" --bert_config_file=pretrained_model/uncased_L-12_H-768_A-12/bert_config.json \\\n", | ||
" --init_checkpoint=output_model/atis_join_task_epoch3/model.ckpt-1399 \\\n", | ||
" --num_train_epochs=3.0 \\\n", | ||
" --output_dir=./output_model/atis_join_task_epoch6/" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## Joint task prediction 联合任务预测 " | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"### SNIPS Make Prediction" | ||
] | ||
}, | ||
{ | ||
"cell_type": "raw", | ||
"metadata": {}, | ||
"source": [ | ||
"python run_sequence_labeling_and_text_classification.py \\\n", | ||
" --task_name=Snips \\\n", | ||
" --do_predict=true \\\n", | ||
" --data_dir=data/snips_Intent_Detection_and_Slot_Filling \\\n", | ||
" --vocab_file=pretrained_model/uncased_L-12_H-768_A-12/vocab.txt \\\n", | ||
" --bert_config_file=pretrained_model/uncased_L-12_H-768_A-12/bert_config.json \\\n", | ||
" --init_checkpoint=output_model/snips_join_task_epoch3/model.ckpt-1000 \\\n", | ||
" --max_seq_length=128 \\\n", | ||
" --output_dir=./output_predict/snips_join_task_epoch3_ckpt1000" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"### ATIS Make Prediction" | ||
] | ||
}, | ||
{ | ||
"cell_type": "raw", | ||
"metadata": {}, | ||
"source": [ | ||
"python run_sequence_labeling_and_text_classification.py \\\n", | ||
" --task_name=Atis \\\n", | ||
" --do_predict=true \\\n", | ||
" --data_dir=data/atis_Intent_Detection_and_Slot_Filling \\\n", | ||
" --vocab_file=pretrained_model/uncased_L-12_H-768_A-12/vocab.txt \\\n", | ||
" --bert_config_file=pretrained_model/uncased_L-12_H-768_A-12/bert_config.json \\\n", | ||
" --init_checkpoint=output_model/atis_join_task_epoch3/model.ckpt-1000 \\\n", | ||
" --max_seq_length=128 \\\n", | ||
" --output_dir=./output_predict/atis_join_task_epoch3_ckpt1000" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.6.7" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
import os | ||
import numpy as np | ||
from sklearn_metrics_function import show_metrics | ||
|
||
# Evaluate ATIS intent predictions against the reference labels.
#
# The position of each name in ATIS_intent_label is the class index used
# below: the predictions file is reshaped to one row per example with one
# column per entry of this list, and the arg-max column selects the label.
ATIS_intent_label = [
    'atis_abbreviation',
    'atis_aircraft',
    'atis_aircraft#atis_flight#atis_flight_no',
    'atis_airfare',
    'atis_airfare#atis_flight',
    'atis_airfare#atis_flight_time',
    'atis_airline',
    'atis_airline#atis_flight_no',
    'atis_airport',
    'atis_capacity',
    'atis_cheapest',
    'atis_city',
    'atis_day_name',
    'atis_distance',
    'atis_flight',
    'atis_flight#atis_airfare',
    'atis_flight#atis_airline',
    'atis_flight_no',
    'atis_flight_no#atis_airline',
    'atis_flight_time',
    'atis_ground_fare',
    'atis_ground_service',
    'atis_ground_service#atis_ground_fare',
    'atis_meal',
    'atis_quantity',
    'atis_restriction',
]

# Reference labels: one per line of ATIS_Intent/label, newline characters removed.
with open(os.path.join("ATIS_Intent", "label")) as label_f:
    label_list = [line.replace("\n", "") for line in label_f]

# Parse every number in the tab-separated results file into a flat array,
# then fold it into rows with one column per known intent label.
# NOTE(review): the values are presumably per-class probabilities written by
# the classifier's predict step -- confirm against the producing script.
flat_scores = np.fromfile(os.path.join("ATIS_Intent", "test_results.tsv"), sep="\t")
score_rows = flat_scores.reshape(-1, len(ATIS_intent_label))

# Index of the highest-scoring column for each row, mapped back to its name.
predit_label_value = score_rows.argmax(axis=1)
predit_label = [ATIS_intent_label[class_idx] for class_idx in predit_label_value]

# Hand reference and predicted labels to the project's metrics helper.
show_metrics(y_test=label_list, y_predict=predit_label, labels=ATIS_intent_label)
Oops, something went wrong.