Is the prediction output of Chinese text-line recognition always a fixed length? #91
Hello, may I ask how the alphabet should be set when training on Chinese?

My alphabet is simply the 5990 Chinese characters concatenated into one string, something of the form alphabet="一二三四五...".

Thanks for the reply. One more question: is the trailing $ symbol required? I see the default is '0:1:2:3:4:5:6:7:8:9:a:b:c:d:e:f:g:h:i:j:k:l:m:n:o:p:q:r:s:t:u:v:w:x:y:z:$'. If I use the 5990-character Chinese set you mentioned, should the ':' used as a separator be dropped?

This is what I wrote: import pickle; alphabet = list(load('5990/5990.pickle').values()). Here '5990.pickle' is the dict {0: '卍', 1: ',', 2: '的', 3: '。', 4: '一', 5: '是', 6: '0', 7: '不', 8: '在', 9: '有', 10: '、', ...}

Thank you so much!
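Putting the exchange together, a minimal sketch of building the alphabet (assuming the {index: char} dict format of '5990/5990.pickle' shown above):

import pickle

# Load the {index: char} dict and concatenate its 5990 characters in
# index order, as the maintainer describes (alphabet = "一二三四五...").
with open('5990/5990.pickle', 'rb') as handle:
    index2char = pickle.load(handle)
alphabet = ''.join(index2char[i] for i in range(len(index2char)))
print(len(alphabet))  # 5990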
Hi author, I trained on the 3.6-million-image synthetic Chinese dataset from https://github.com/YCG09/chinese_ocr (I no longer remember the exact data source). My question is: why does my trained model always output exactly 10 characters? Accuracy on the training and test sets is acceptable, but performance on text lines from ordinary documents is poor. Is there anything I should pay attention to in a case like this, or any parameter that needs changing?
Here is part of the model's output:
correct / total: 32683 / 36440,
Test loss: 0.042277, accuray: 0.896899
Epoch: 4/10; iter: 11000/28185; Loss: 0.033452; time: 110.62 s;
Epoch: 4/10; iter: 11100/28185; Loss: 0.031541; time: 37.33 s;
Epoch: 4/10; iter: 11200/28185; Loss: 0.029762; time: 37.43 s;
Epoch: 4/10; iter: 11300/28185; Loss: 0.031392; time: 37.54 s;
Epoch: 4/10; iter: 11400/28185; Loss: 0.032547; time: 37.77 s;
Epoch: 4/10; iter: 11500/28185; Loss: 0.033517; time: 38.46 s;
Epoch: 4/10; iter: 11600/28185; Loss: 0.033412; time: 40.39 s;
Epoch: 4/10; iter: 11700/28185; Loss: 0.030087; time: 39.20 s;
Epoch: 4/10; iter: 11800/28185; Loss: 0.031644; time: 38.34 s;
Epoch: 4/10; iter: 11900/28185; Loss: 0.034695; time: 38.58 s;
Start val
correct / total: 32922 / 36440,
Test loss: 0.039564, accuray: 0.903458
Epoch: 4/10; iter: 12000/28185; Loss: 0.035296; time: 115.10 s;
Epoch: 4/10; iter: 12100/28185; Loss: 0.032580; time: 38.55 s;
Epoch: 4/10; iter: 12200/28185; Loss: 0.031885; time: 38.64 s;
The dataset consists of text lines synthesized from the 5990 Chinese characters, converted to LMDB format for training.
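Given the fixed 10-character outputs, one sanity check worth running first (a sketch; the LMDB path is a placeholder and the 'num-samples' / 'label-%09d' key names assume the conventional CRNN layout): if every training label is exactly 10 characters long, a model that always emits 10 characters would be unsurprising.

import lmdb

# Hypothetical check: collect the label lengths of the first 1000 samples.
env = lmdb.open('path/to/train_lmdb', readonly=True, lock=False)
with env.begin() as txn:
    n = int(txn.get('num-samples'.encode()))
    lengths = set()
    for i in range(1, min(n, 1000) + 1):
        lengths.add(len(txn.get(('label-%09d' % i).encode()).decode()))
print('label lengths seen:', sorted(lengths))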
create_dataset.py is as follows:
import os
import lmdb  # install lmdb by "pip install lmdb"
import cv2
import pickle
import numpy as np

def checkImageIsValid(imageBin):
    if imageBin is None:
        return False
    imageBuf = np.frombuffer(imageBin, dtype=np.uint8)  # np.fromstring is deprecated
    img = cv2.imdecode(imageBuf, cv2.IMREAD_GRAYSCALE)
    if img is None:  # cv2.imdecode returns None on undecodable bytes
        return False
    imgH, imgW = img.shape[0], img.shape[1]
    if imgH * imgW == 0:
        return False
    return True

def writeCache(env, cache):
    with env.begin(write=True) as txn:
        for k, v in cache.items():
            if type(v) == str:
                v = v.encode()
            txn.put(k.encode(), v)
def createDataset(outputPath, imagePathList, labelList, lexiconList=None, checkValid=True):
    """
    Create LMDB dataset for CRNN training.
    """
    ...  # (function body omitted in the original post)

def load(filename='5990.pickle'):
    with open(filename, 'rb') as handle:
        index2char = pickle.load(handle)
    return index2char

if __name__ == '__main__':
    ...  # (main block omitted in the original post)
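Since the body of createDataset was not posted, here is a minimal sketch of the conventional CRNN-style LMDB writer that checkImageIsValid and writeCache serve, storing each sample under 'image-%09d' / 'label-%09d' keys plus a 'num-samples' count; this reconstruction is an assumption, not the poster's exact code:

def createDataset(outputPath, imagePathList, labelList, lexiconList=None, checkValid=True):
    """Sketch of the conventional CRNN LMDB writer."""
    assert len(imagePathList) == len(labelList)
    env = lmdb.open(outputPath, map_size=1099511627776)  # 1 TB map size
    cache = {}
    cnt = 1
    for imagePath, label in zip(imagePathList, labelList):
        with open(imagePath, 'rb') as f:
            imageBin = f.read()
        if checkValid and not checkImageIsValid(imageBin):
            print('%s is not a valid image' % imagePath)
            continue
        cache['image-%09d' % cnt] = imageBin
        cache['label-%09d' % cnt] = label
        if cnt % 1000 == 0:
            writeCache(env, cache)  # flush periodically to bound memory use
            cache = {}
        cnt += 1
    cache['num-samples'] = str(cnt - 1)
    writeCache(env, cache)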
tools/dataset.py was modified as follows:
class lmdbDataset(Dataset):
    ...  # (the modified implementation was omitted in the original post)
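Since the modification itself was omitted, a hedged sketch of what an lmdbDataset reading that layout could look like; the constructor arguments mirror the calls in main.py below, but the details are assumptions, not the poster's code:

import lmdb
import six
from PIL import Image
from torch.utils.data import Dataset

class lmdbDataset(Dataset):
    def __init__(self, root=None, transform=None, reverse=False, alphabet=None):
        self.env = lmdb.open(root, max_readers=1, readonly=True,
                             lock=False, readahead=False, meminit=False)
        with self.env.begin(write=False) as txn:
            self.nSamples = int(txn.get('num-samples'.encode()))
        self.transform = transform
        self.reverse = reverse
        self.alphabet = alphabet

    def __len__(self):
        return self.nSamples

    def __getitem__(self, index):
        index += 1  # LMDB sample keys are 1-based
        with self.env.begin(write=False) as txn:
            buf = six.BytesIO(txn.get(('image-%09d' % index).encode()))
            img = Image.open(buf).convert('L')
            label = txn.get(('label-%09d' % index).encode()).decode()
        if self.transform is not None:
            img = self.transform(img)
        if self.reverse:
            return img, label, label[::-1]  # reversed label for the bidirectional decoder
        return img, label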
main.py:
from __future__ import print_function
import argparse
import random
import torch
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
from torch.autograd import Variable
import numpy as np
import os
import tools.utils as utils
import tools.dataset as dataset
import time
from collections import OrderedDict
from models.moran import MORAN
parser = argparse.ArgumentParser()
parser.add_argument('--train_nips', required=True, help='path to dataset')
parser.add_argument('--train_cvpr', required=True, help='path to dataset')
parser.add_argument('--valroot', required=True, help='path to dataset')
parser.add_argument('--workers', type=int, help='number of data loading workers', default=2)
parser.add_argument('--batchSize', type=int, default=64, help='input batch size')
parser.add_argument('--imgH', type=int, default=64, help='the height of the input image to network')
parser.add_argument('--imgW', type=int, default=200, help='the width of the input image to network')
parser.add_argument('--targetH', type=int, default=32, help='the height of the target image')
parser.add_argument('--targetW', type=int, default=100, help='the width of the target image')
parser.add_argument('--nh', type=int, default=256, help='size of the lstm hidden state')
parser.add_argument('--niter', type=int, default=10, help='number of epochs to train for')
parser.add_argument('--lr', type=float, default=0.01, help='learning rate, default=0.01')
parser.add_argument('--cuda', action='store_true', help='enables cuda')
parser.add_argument('--ngpu', type=int, default=1, help='number of GPUs to use')
parser.add_argument('--MORAN', default='', help="path to model (to continue training)")
parser.add_argument('--alphabet', type=str, default='0:1:2:3:4:5:6:7:8:9:a:b:c:d:e:f:g:h:i:j:k:l:m:n:o:p:q:r:s:t:u:v:w:x:y:z:$')
parser.add_argument('--sep', type=str, default=':')
parser.add_argument('--experiment', default=None, help='Where to store samples and models')
parser.add_argument('--displayInterval', type=int, default=500, help='interval (iterations) between training-progress printouts')
parser.add_argument('--n_test_disp', type=int, default=10, help='Number of samples to display when test')
parser.add_argument('--valInterval', type=int, default=500, help='interval (iterations) between validation runs')
parser.add_argument('--saveInterval', type=int, default=10000, help='interval (iterations) between model checkpoints')
parser.add_argument('--adam', action='store_true', help='Whether to use adam (default is rmsprop)')
parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam. default=0.5')
parser.add_argument('--adadelta', action='store_true', help='Whether to use adadelta (default is rmsprop)')
parser.add_argument('--sgd', action='store_true', help='Whether to use sgd (default is rmsprop)')
parser.add_argument('--BidirDecoder', action='store_true', help='Whether to use BidirDecoder')
opt = parser.parse_args()
print(opt)
assert opt.ngpu == 1, "Multi-GPU training is not supported yet, due to the variant lengths of the text in a batch."
if opt.experiment is None:
    opt.experiment = 'expr'
os.system('mkdir {0}'.format(opt.experiment))
opt.manualSeed = random.randint(1, 10000) # fix seed
print("Random Seed: ", opt.manualSeed)
random.seed(opt.manualSeed)
np.random.seed(opt.manualSeed)
torch.manual_seed(opt.manualSeed)
cudnn.benchmark = True
if not torch.cuda.is_available():
    assert not opt.cuda, "You don't have a CUDA device."
if torch.cuda.is_available() and not opt.cuda:
    print("WARNING: You have a CUDA device, so you should probably run with --cuda")
import pickle

def load(filename='5990.pickle'):
    with open(filename, 'rb') as handle:
        index2char = pickle.load(handle)
    return index2char

alphabet = list(load('5990/5990.pickle').values())
opt.alphabet = '♫'.join(alphabet)  # .replace('ⅰ','Ⅰ').replace('ⅱ','Ⅱ').replace('ⅲ','Ⅲ').replace('ⅳ','Ⅳ').replace('ⅴ','Ⅴ').replace('ⅵ','Ⅵ').replace('ⅶ','Ⅶ').replace('ⅷ','Ⅷ').replace('ⅸ','Ⅸ').replace('р','p').replace('п', 'n').replace('о', 'o')
opt.sep = '♫'
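One guard worth adding here (my suggestion, not part of the original post): strLabelConverterForAttention splits opt.alphabet on opt.sep, so the scheme silently breaks if the separator ever occurs inside the character set.

# Suggested guard: the join/split round-trip must be lossless.
assert all(opt.sep not in ch for ch in alphabet), 'separator collides with charset'
assert opt.alphabet.split(opt.sep) == alphabet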
# 1.1 Training sets (nips + cvpr)
train_nips_dataset = dataset.lmdbDataset(root=opt.train_nips,
    transform=dataset.resizeNormalize((opt.imgW, opt.imgH)), reverse=opt.BidirDecoder, alphabet=''.join(alphabet))
assert train_nips_dataset
train_cvpr_dataset = dataset.lmdbDataset(root=opt.train_cvpr,
    transform=dataset.resizeNormalize((opt.imgW, opt.imgH)), reverse=opt.BidirDecoder, alphabet=''.join(alphabet))
assert train_cvpr_dataset
train_dataset = torch.utils.data.ConcatDataset([train_nips_dataset, train_cvpr_dataset])
train_dataset = train_nips_dataset  # NOTE: this overwrites the ConcatDataset above, so only the nips set is actually used
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=opt.batchSize,
    shuffle=False, sampler=dataset.randomSequentialSampler(train_dataset, opt.batchSize),
    num_workers=int(opt.workers))
# 1.2 Test set
test_dataset = dataset.lmdbDataset(root=opt.valroot,
    transform=dataset.resizeNormalize((opt.imgW, opt.imgH)), reverse=opt.BidirDecoder, alphabet=''.join(alphabet))
# 2. Number of output classes
nclass = len(opt.alphabet.split(opt.sep))
print('nclass = ', nclass)
nc = 1
converter = utils.strLabelConverterForAttention(opt.alphabet, opt.sep)
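As a sanity check on the class count: the default ':'-separated alphabet splits into 37 classes (10 digits, 26 letters, and '$'), while the 5990-character alphabet joined above splits into 5990.

default = '0:1:2:3:4:5:6:7:8:9:a:b:c:d:e:f:g:h:i:j:k:l:m:n:o:p:q:r:s:t:u:v:w:x:y:z:$'
assert len(default.split(':')) == 37
assert nclass == len(alphabet)  # 5990 with the character set used here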
# 3. Loss function
criterion = torch.nn.CrossEntropyLoss()
opt.cuda = True  # force CUDA on, overriding the --cuda flag
if opt.cuda:
    MORAN = MORAN(nc, nclass, opt.nh, opt.targetH, opt.targetW, BidirDecoder=opt.BidirDecoder, CUDA=opt.cuda)
else:
    MORAN = MORAN(nc, nclass, opt.nh, opt.targetH, opt.targetW, BidirDecoder=opt.BidirDecoder, inputDataType='torch.FloatTensor', CUDA=opt.cuda)
if opt.MORAN != '':
    print('loading pretrained model from %s' % opt.MORAN)
    if opt.cuda:
        state_dict = torch.load(opt.MORAN)
    else:
        state_dict = torch.load(opt.MORAN, map_location='cpu')
    MORAN_state_dict_rename = OrderedDict()
    for k, v in state_dict.items():
        name = k.replace("module.", "")  # strip the "module." prefix that nn.DataParallel adds on save
        MORAN_state_dict_rename[name] = v
    MORAN.load_state_dict(MORAN_state_dict_rename, strict=True)
image = torch.FloatTensor(opt.batchSize, nc, opt.imgH, opt.imgW)
text = torch.LongTensor(opt.batchSize * 5)
text_rev = torch.LongTensor(opt.batchSize * 5)
length = torch.IntTensor(opt.batchSize)
if opt.cuda:
    MORAN.cuda()
    MORAN = torch.nn.DataParallel(MORAN, device_ids=range(opt.ngpu))
    image = image.cuda()
    text = text.cuda()
    text_rev = text_rev.cuda()
    criterion = criterion.cuda()
image = Variable(image)
text = Variable(text)
text_rev = Variable(text_rev)
length = Variable(length)
# loss averager
loss_avg = utils.averager()
# setup optimizer
if opt.adam:
    optimizer = optim.Adam(MORAN.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
elif opt.adadelta:
    optimizer = optim.Adadelta(MORAN.parameters(), lr=opt.lr)
elif opt.sgd:
    optimizer = optim.SGD(MORAN.parameters(), lr=opt.lr, momentum=0.9)
else:
    optimizer = optim.RMSprop(MORAN.parameters(), lr=opt.lr)
def val(dataset, criterion, max_iter=1000):
    print('Start val')
    data_loader = torch.utils.data.DataLoader(
        dataset, shuffle=False, batch_size=opt.batchSize, num_workers=int(opt.workers))
    val_iter = iter(data_loader)
    max_iter = min(max_iter, len(data_loader))
    n_correct = 0
    n_total = 0
    loss_avg = utils.averager()
    ...  # (rest of val omitted in the original post)
def trainBatch():
    data = next(train_iter)  # train_iter.next() is Python 2 only; next() works on both
    if opt.BidirDecoder:
        cpu_images, cpu_texts, cpu_texts_rev = data
        utils.loadData(image, cpu_images)
        t, l = converter.encode(cpu_texts, scanned=True)
        t_rev, _ = converter.encode(cpu_texts_rev, scanned=True)
        utils.loadData(text, t)
        utils.loadData(text_rev, t_rev)
        utils.loadData(length, l)
        preds0, preds1 = MORAN(image, length, text, text_rev)
        cost = criterion(torch.cat([preds0, preds1], 0), torch.cat([text, text_rev], 0))
    else:
        cpu_images, cpu_texts = data
        utils.loadData(image, cpu_images)
        t, l = converter.encode(cpu_texts, scanned=True)
        utils.loadData(text, t)
        utils.loadData(length, l)
        preds = MORAN(image, length, text, text_rev)
        cost = criterion(preds, text)
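    # (the post truncates trainBatch here; the reference MORAN training script
    #  typically finishes the batch step as follows -- a sketch, not the
    #  poster's exact code)
    loss_avg.add(cost)
    MORAN.zero_grad()
    cost.backward()
    optimizer.step()
    return cost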
t0 = time.time()
acc = 0
acc_tmp = 0
for epoch in range(opt.niter):
    ...  # (training-loop body omitted in the original post)