Skip to content

Commit

Permalink
Signed-off-by: Terry chan <napoler2008@gmail.com>
Browse files Browse the repository at this point in the history
  • Loading branch information
napoler committed Jul 6, 2020
1 parent 9cdcdd0 commit 803ed86
Show file tree
Hide file tree
Showing 4 changed files with 1,733 additions and 21 deletions.
3 changes: 3 additions & 0 deletions docs/readme.md
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ python3 train.py --epochs 1 --device cpu --batch_size 320 --gradient_accumulatio

当然这是其中比较好的,大体还是不错的。

和使用mark模式不同的是,使用生成模型做训练的结果是可以提取出原文不存在的一些知识。比如所属专辑这种,不是原文中存在的文字,显然使用mark模式是做不到的。
不过生成模型也有时候会出现太过放飞自我的问题,比如出现不存在的名字或者实体等等。


## 感谢

Expand Down
42 changes: 22 additions & 20 deletions generate_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,23 +127,25 @@ def get_ppl(start_text):
# ppl = math.exp(loss.mean().item())
# print(ppl)

# args = parser.parse_args()
file="data/train.txt"
fp=open(file,'r')
lines = fp.readlines()
for line in lines:
print("\n"*3)
print("###"*20)
print("原始语句:",line.split("[KGS]")[0])
try:
print("正确知识:",line.split("[KGS]")[1])
except:
pass
# print(line.split("[KGS]"))
start_text=line.split("[KGS]")[0]+" [KGS] "
print(get_kg(start_text))
# pre_text=get(start_text)
# p="".join(pre_text)
# # p.split("[/KGS]")[0]
# print("预测结果",p.split("[/KGS]")[0])
# # print(get_ppl(start_text+"".join(pre_text)))

if __name__=='__main__':
# args = parser.parse_args()
file="data/train.txt"
fp=open(file,'r')
lines = fp.readlines()
for line in lines:
print("\n"*3)
print("###"*20)
print("原始语句:",line.split("[KGS]")[0])
try:
print("正确知识:",line.split("[KGS]")[1])
except:
pass
# print(line.split("[KGS]"))
start_text=line.split("[KGS]")[0]+" [KGS] "
print(get_kg(start_text))
# pre_text=get(start_text)
# p="".join(pre_text)
# # p.split("[/KGS]")[0]
# print("预测结果",p.split("[/KGS]")[0])
# # print(get_ppl(start_text+"".join(pre_text)))
1,649 changes: 1,648 additions & 1 deletion notebook/reformer_pytorch_chinese.ipynb

Large diffs are not rendered by default.

60 changes: 60 additions & 0 deletions notebook/未命名.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"ename": "ModuleNotFoundError",
"evalue": "No module named '__main__.generate_test'; '__main__' is not a package",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-2-79fd179ea496>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0;32mfrom\u001b[0m \u001b[0;34m.\u001b[0m\u001b[0mgenerate_test\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[0;31mModuleNotFoundError\u001b[0m: No module named '__main__.generate_test'; '__main__' is not a package"
]
}
],
"source": [
"from .generate_test import *"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.9"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

0 comments on commit 803ed86

Please sign in to comment.