Are the predicted outputs of Chinese text-line recognition always a fixed length? #91

Closed
kUhNCwlVbsWXClUR opened this issue Jul 21, 2019 · 5 comments

@kUhNCwlVbsWXClUR

Hello author, I trained on the synthetic 3.6-million-sample Chinese dataset from https://github.com/YCG09/chinese_ocr (I forget the exact source of the data). What I want to ask is: why does the model I trained always output a fixed length of 10 characters? The results on the training and validation sets are acceptable, but the performance on text lines from ordinary documents is poor. I would like to ask what needs attention in a situation like this, or which parameters should be changed.

Here is part of the model's training output:
correct / total: 32683 / 36440,
Test loss: 0.042277, accuracy: 0.896899
Epoch: 4/10; iter: 11000/28185; Loss: 0.033452; time: 110.62 s;
Epoch: 4/10; iter: 11100/28185; Loss: 0.031541; time: 37.33 s;
Epoch: 4/10; iter: 11200/28185; Loss: 0.029762; time: 37.43 s;
Epoch: 4/10; iter: 11300/28185; Loss: 0.031392; time: 37.54 s;
Epoch: 4/10; iter: 11400/28185; Loss: 0.032547; time: 37.77 s;
Epoch: 4/10; iter: 11500/28185; Loss: 0.033517; time: 38.46 s;
Epoch: 4/10; iter: 11600/28185; Loss: 0.033412; time: 40.39 s;
Epoch: 4/10; iter: 11700/28185; Loss: 0.030087; time: 39.20 s;
Epoch: 4/10; iter: 11800/28185; Loss: 0.031644; time: 38.34 s;
Epoch: 4/10; iter: 11900/28185; Loss: 0.034695; time: 38.58 s;
Start val
correct / total: 32922 / 36440,
Test loss: 0.039564, accuracy: 0.903458
Epoch: 4/10; iter: 12000/28185; Loss: 0.035296; time: 115.10 s;
Epoch: 4/10; iter: 12100/28185; Loss: 0.032580; time: 38.55 s;
Epoch: 4/10; iter: 12200/28185; Loss: 0.031885; time: 38.64 s;

The dataset consists of text lines synthesized from the 5990 Chinese characters, converted to lmdb format for training.
create_dataset.py is as follows:
import os
import lmdb # install lmdb by "pip install lmdb"
import cv2
import pickle
import numpy as np

def checkImageIsValid(imageBin):
    if imageBin is None:
        return False
    imageBuf = np.fromstring(imageBin, dtype=np.uint8)
    img = cv2.imdecode(imageBuf, cv2.IMREAD_GRAYSCALE)
    if img is None:
        # cv2.imdecode returns None for undecodable data; treat it as invalid
        return False
    imgH, imgW = img.shape[0], img.shape[1]
    if imgH * imgW == 0:
        return False
    return True

def writeCache(env, cache):
    with env.begin(write=True) as txn:
        for k, v in cache.items():
            if type(v) == str:
                v = v.encode()
            txn.put(k.encode(), v)

def createDataset(outputPath, imagePathList, labelList, lexiconList=None, checkValid=True):
    """
    Create LMDB dataset for CRNN training.

    ARGS:
        outputPath    : LMDB output path
        imagePathList : list of image path
        labelList     : list of corresponding groundtruth texts
        lexiconList   : (optional) list of lexicon lists
        checkValid    : if true, check the validity of every image
    """
    assert(len(imagePathList) == len(labelList))
    nSamples = len(imagePathList)
    env = lmdb.open(outputPath, map_size=1099511627776)
    cache = {}
    cnt = 1
    for i in range(nSamples):
        imagePath = imagePathList[i]
        label = labelList[i]
        if not os.path.exists(imagePath):
            print('%s does not exist' % imagePath)
            continue
        with open(imagePath, 'rb') as f:
            imageBin = f.read()
        if checkValid:
            if not checkImageIsValid(imageBin):
                print('%s is not a valid image' % imagePath)
                continue

        imageKey = 'image-%09d' % cnt
        labelKey = 'label-%09d' % cnt
        cache[imageKey] = imageBin
        cache[labelKey] = label
        if lexiconList:
            lexiconKey = 'lexicon-%09d' % cnt
            cache[lexiconKey] = ' '.join(lexiconList[i])
        if cnt % 1000 == 0:
            writeCache(env, cache)
            cache = {}
            print('Written %d / %d' % (cnt, nSamples))
        cnt += 1
    nSamples = cnt - 1
    cache['num-samples'] = str(nSamples)
    writeCache(env, cache)
    print('Created dataset with %d samples' % nSamples)

def load(filename='5990.pickle'):
    with open(filename, 'rb') as handle:
        index2char = pickle.load(handle)
    return index2char


if __name__ == '__main__':

    root_dir = '/datasets/images/5990'

    """
    Directory layout of the original images:
    /datasets/images/5990
        -- xxx1.jpg
        -- xxx2.jpg
        -- /xxx3.jpg
        -- /xxx4.jpg
        -- ....

        -- /train.txt
        -- /val.txt

    Each line of val.txt looks like the following; the numbers are character
    indices (the separator is a space):
        xxxx.jpg 12 23 89 12 324 3243 3242 12 123 2349
    """

    index2char = load(filename="5990.pickle")
    output = "/datasets/lmdb/5990"

    if not os.path.exists(output):
        os.mkdir(output)

    """
    output
        path where the lmdb files will be written
    """

    for train in ['train', 'val']:
        outputPath = os.path.join(output, train)
        imagePathList = []
        labelList = []

        with open(os.path.join(root_dir, '{}.txt'.format(train))) as f:
            for line in f.readlines():
                txt = line.strip('\n').split()
                imagePath, label = txt[0], ''.join([index2char[int(index)] for index in txt[1:]])
                imagePath = os.path.join(root_dir, imagePath)
                imagePathList.append(imagePath)
                labelList.append(label)
                # print(imagePath, label, len(label))

        createDataset(outputPath, imagePathList, labelList, lexiconList=None, checkValid=True)
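
As a quick check (my own sketch, not part of the original scripts; it only assumes the num-samples, image-%09d and label-%09d keys written by createDataset above, and inspect_lmdb is a hypothetical helper name), the lmdb can be read back to confirm that labels were stored with their true, variable lengths:

import lmdb

def inspect_lmdb(path, n=3):
    # Open the lmdb written by createDataset() and print a few samples.
    env = lmdb.open(path, readonly=True, lock=False)
    with env.begin(write=False) as txn:
        total = int(txn.get('num-samples'.encode()))
        print('num-samples:', total)
        for i in range(1, min(n, total) + 1):
            label = txn.get(('label-%09d' % i).encode()).decode('utf-8')
            img_len = len(txn.get(('image-%09d' % i).encode()))
            print(i, label, len(label), '%d bytes' % img_len)

# Hypothetical path, matching the output directory used above:
# inspect_lmdb('/datasets/lmdb/5990/train')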

tools/dataset.py was modified as follows:
class lmdbDataset(Dataset):

    def __init__(self, root=None, transform=None, reverse=False, alphabet='0123456789abcdefghijklmnopqrstuvwxyz'):
        self.env = lmdb.open(
            root,
            max_readers=1,
            readonly=True,
            lock=False,
            readahead=False,
            meminit=False)

        if not self.env:
            print('cannot create lmdb from %s' % (root))
            sys.exit(0)

        with self.env.begin(write=False) as txn:
            # yugengde: check how the dataset was laid out at creation time and restore it accordingly
            nSamples = int(txn.get('num-samples'.encode()))
            self.nSamples = nSamples

        self.transform = transform
        self.alphabet = alphabet
        self.reverse = reverse

    def __len__(self):
        return self.nSamples

    def __getitem__(self, index):
        assert index <= len(self), 'index range error'
        index += 1
        with self.env.begin(write=False) as txn:
            label_key = 'label-%09d'.encode() % index
            label = txn.get(label_key).decode('utf-8')
            img_key = 'image-%09d'.encode() % index
            imgbuf = txn.get(img_key)

            buf = six.BytesIO()
            buf.write(imgbuf)
            buf.seek(0)
            try:
                img = Image.open(buf).convert('L')
            except IOError:
                print('Corrupted image for %d' % index)
                return self[index + 1]

            # label = ''.join([self.alphabet[int(i)-1] for i in label.split()])

            # label = ''.join(label[i] if label[i].lower() in self.alphabet else ''
            #     for i in range(len(label)))
            # if len(label) <= 0:
            #     return self[index + 1]
            if self.reverse:
                label_rev = label[-1::-1]
                label_rev += self.alphabet[0]
            label += self.alphabet[0]

            if self.transform is not None:
                img = self.transform(img)
        if self.reverse:
            return (img, label, label_rev)
        else:
            return (img, label)
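
A small sanity check (again my own sketch, assuming the lmdb paths and the 5990/5990.pickle file mentioned elsewhere in this issue, and that the lmdb contains at least a few samples) verifies that each returned label ends with the stop character alphabet[0] and that label lengths vary per sample rather than being fixed:

import pickle
from tools.dataset import lmdbDataset

with open('5990/5990.pickle', 'rb') as handle:
    index2char = pickle.load(handle)
alphabet = ''.join(index2char.values())

# reverse=False, transform=None -> __getitem__ returns (PIL image, label string)
ds = lmdbDataset(root='/datasets/lmdb/5990/val', alphabet=alphabet)
for idx in range(3):
    img, label = ds[idx]
    # label should end with the stop symbol alphabet[0] ('卍' in this setup)
    print(idx, label, len(label), label.endswith(alphabet[0]))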

main.py
from __future__ import print_function
import argparse
import random
import torch
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
from torch.autograd import Variable
import numpy as np
import os
import tools.utils as utils
import tools.dataset as dataset
import time
from collections import OrderedDict
from models.moran import MORAN

parser = argparse.ArgumentParser()
parser.add_argument('--train_nips', required=True, help='path to dataset')
parser.add_argument('--train_cvpr', required=True, help='path to dataset')
parser.add_argument('--valroot', required=True, help='path to dataset')
parser.add_argument('--workers', type=int, help='number of data loading workers', default=2)
parser.add_argument('--batchSize', type=int, default=64, help='input batch size')
parser.add_argument('--imgH', type=int, default=64, help='the height of the input image to network')
parser.add_argument('--imgW', type=int, default=200, help='the width of the input image to network')
parser.add_argument('--targetH', type=int, default=32, help='the height of the rectified image fed to the recognition network')
parser.add_argument('--targetW', type=int, default=100, help='the width of the rectified image fed to the recognition network')
parser.add_argument('--nh', type=int, default=256, help='size of the lstm hidden state')
parser.add_argument('--niter', type=int, default=10, help='number of epochs to train for')
parser.add_argument('--lr', type=float, default=0.01, help='learning rate for Critic, default=0.00005')
parser.add_argument('--cuda', action='store_true', help='enables cuda')
parser.add_argument('--ngpu', type=int, default=1, help='number of GPUs to use')
parser.add_argument('--MORAN', default='', help="path to model (to continue training)")
parser.add_argument('--alphabet', type=str, default='0:1:2:3:4:5:6:7:8:9:a:b:c:d:e:f:g:h:i:j:k:l:m:n:o:p:q:r:s:t:u:v:w:x:y:z:$')
parser.add_argument('--sep', type=str, default=':')
parser.add_argument('--experiment', default=None, help='Where to store samples and models')
parser.add_argument('--displayInterval', type=int, default=500, help='Interval to be displayed')
parser.add_argument('--n_test_disp', type=int, default=10, help='Number of samples to display when test')
parser.add_argument('--valInterval', type=int, default=500, help='Interval to be displayed')
parser.add_argument('--saveInterval', type=int, default=10000, help='Interval to be displayed')
parser.add_argument('--adam', action='store_true', help='Whether to use adam (default is rmsprop)')
parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam. default=0.5')
parser.add_argument('--adadelta', action='store_true', help='Whether to use adadelta (default is rmsprop)')
parser.add_argument('--sgd', action='store_true', help='Whether to use sgd (default is rmsprop)')
parser.add_argument('--BidirDecoder', action='store_true', help='Whether to use BidirDecoder')
opt = parser.parse_args()
print(opt)

assert opt.ngpu == 1, "Multi-GPU training is not supported yet, due to the variant lengths of the text in a batch."

if opt.experiment is None:
    opt.experiment = 'expr'
os.system('mkdir {0}'.format(opt.experiment))

opt.manualSeed = random.randint(1, 10000) # fix seed
print("Random Seed: ", opt.manualSeed)
random.seed(opt.manualSeed)
np.random.seed(opt.manualSeed)
torch.manual_seed(opt.manualSeed)

cudnn.benchmark = True

if not torch.cuda.is_available():
    assert not opt.cuda, "You don't have a CUDA device."

if torch.cuda.is_available() and not opt.cuda:
    print("WARNING: You have a CUDA device, so you should probably run with --cuda")

import pickle
def load(filename='5990.pickle'):
    with open(filename, 'rb') as handle:
        index2char = pickle.load(handle)
    return index2char

alphabet = list(load('5990/5990.pickle').values())
opt.alphabet = '♫'.join(alphabet)  # .replace('ⅰ','Ⅰ').replace('ⅱ','Ⅱ').replace('ⅲ','Ⅲ').replace('ⅳ','Ⅳ').replace('ⅴ','Ⅴ').replace('ⅵ','Ⅵ').replace('ⅶ','Ⅶ').replace('ⅷ','Ⅷ').replace('ⅸ','Ⅸ').replace('р','p').replace('п', 'n').replace('о', 'o')
opt.sep = '♫'

# 1.1 Training set (nips + cvpr)

train_nips_dataset = dataset.lmdbDataset(root=opt.train_nips,
    transform=dataset.resizeNormalize((opt.imgW, opt.imgH)), reverse=opt.BidirDecoder, alphabet=''.join(alphabet))
assert train_nips_dataset
train_cvpr_dataset = dataset.lmdbDataset(root=opt.train_cvpr,
    transform=dataset.resizeNormalize((opt.imgW, opt.imgH)), reverse=opt.BidirDecoder, alphabet=''.join(alphabet))

assert train_cvpr_dataset

train_dataset = torch.utils.data.ConcatDataset([train_nips_dataset, train_cvpr_dataset])

train_dataset = train_nips_dataset  # note: this overrides the ConcatDataset above, so only the nips lmdb is actually used

train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=opt.batchSize,
    shuffle=False, sampler=dataset.randomSequentialSampler(train_dataset, opt.batchSize),
    num_workers=int(opt.workers))

# 1.2 Test set

test_dataset = dataset.lmdbDataset(root=opt.valroot,
    transform=dataset.resizeNormalize((opt.imgW, opt.imgH)), reverse=opt.BidirDecoder, alphabet=''.join(alphabet))

# 2. Number of output classes

nclass = len(opt.alphabet.split(opt.sep))
print('nclass = ', nclass)
nc = 1

converter = utils.strLabelConverterForAttention(opt.alphabet, opt.sep)
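# --- Editorial sketch, not part of the original main.py ---
# Assumption: strLabelConverterForAttention.encode/decode behave as they are used in
# val() further below, i.e. encode() returns a flat index tensor plus per-sample
# lengths and decode() reverses it. Uncommenting this should show that labels of
# different lengths survive the round trip, so the converter itself does not force
# predictions to a fixed length.
# sample_texts = ['你好卍', '中国人民卍']
# t, l = converter.encode(sample_texts, scanned=True)
# print(l)                           # expected: per-sample lengths, e.g. 3 and 5
# print(converter.decode(t, l))      # expected: the original strings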

# 3. Loss function

criterion = torch.nn.CrossEntropyLoss()

opt.cuda = True

if opt.cuda:
    MORAN = MORAN(nc, nclass, opt.nh, opt.targetH, opt.targetW, BidirDecoder=opt.BidirDecoder, CUDA=opt.cuda)
else:
    MORAN = MORAN(nc, nclass, opt.nh, opt.targetH, opt.targetW, BidirDecoder=opt.BidirDecoder, inputDataType='torch.FloatTensor', CUDA=opt.cuda)

if opt.MORAN != '':
    print('loading pretrained model from %s' % opt.MORAN)
    if opt.cuda:
        state_dict = torch.load(opt.MORAN)
    else:
        state_dict = torch.load(opt.MORAN, map_location='cpu')
    MORAN_state_dict_rename = OrderedDict()
    for k, v in state_dict.items():
        name = k.replace("module.", "")  # remove module.
        MORAN_state_dict_rename[name] = v
    MORAN.load_state_dict(MORAN_state_dict_rename, strict=True)

image = torch.FloatTensor(opt.batchSize, nc, opt.imgH, opt.imgW)
text = torch.LongTensor(opt.batchSize * 5)
text_rev = torch.LongTensor(opt.batchSize * 5)
length = torch.IntTensor(opt.batchSize)

if opt.cuda:
    MORAN.cuda()
    MORAN = torch.nn.DataParallel(MORAN, device_ids=range(opt.ngpu))
    image = image.cuda()
    text = text.cuda()
    text_rev = text_rev.cuda()
    criterion = criterion.cuda()

image = Variable(image)
text = Variable(text)
text_rev = Variable(text_rev)
length = Variable(length)

# loss averager

loss_avg = utils.averager()

# setup optimizer

if opt.adam:
    optimizer = optim.Adam(MORAN.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
elif opt.adadelta:
    optimizer = optim.Adadelta(MORAN.parameters(), lr=opt.lr)
elif opt.sgd:
    optimizer = optim.SGD(MORAN.parameters(), lr=opt.lr, momentum=0.9)
else:
    optimizer = optim.RMSprop(MORAN.parameters(), lr=opt.lr)

def val(dataset, criterion, max_iter=1000):
    print('Start val')
    data_loader = torch.utils.data.DataLoader(
        dataset, shuffle=False, batch_size=opt.batchSize, num_workers=int(opt.workers))  # opt.batchSize
    val_iter = iter(data_loader)
    max_iter = min(max_iter, len(data_loader))
    n_correct = 0
    n_total = 0
    loss_avg = utils.averager()

    # import pdb; pdb.set_trace()
    for i in range(max_iter):
        data = val_iter.next()
        if opt.BidirDecoder:
            cpu_images, cpu_texts, cpu_texts_rev = data
            utils.loadData(image, cpu_images)
            t, l = converter.encode(cpu_texts, scanned=True)
            t_rev, _ = converter.encode(cpu_texts_rev, scanned=True)
            utils.loadData(text, t)
            utils.loadData(text_rev, t_rev)
            utils.loadData(length, l)
            preds0, preds1 = MORAN(image, length, text, text_rev, test=True)
            cost = criterion(torch.cat([preds0, preds1], 0), torch.cat([text, text_rev], 0))
            preds0_prob, preds0 = preds0.max(1)
            preds0 = preds0.view(-1)
            preds0_prob = preds0_prob.view(-1)
            # preds0.data: tensor([2567, 2567, 4529,  ..., 1029, 8064, 5824], device='cuda:0')  128*11 = 1408 indices
            # length.data: [11, 11, 11, ...]  128 lengths

            sim_preds0 = converter.decode(preds0.data, length.data)
            preds1_prob, preds1 = preds1.max(1)
            preds1 = preds1.view(-1)
            preds1_prob = preds1_prob.view(-1)
            sim_preds1 = converter.decode(preds1.data, length.data)
            sim_preds = []
            for j in range(cpu_images.size(0)):
                text_begin = 0 if j == 0 else length.data[:j].sum()
                # import pdb;pdb.set_trace()
                # if torch.mean(preds0_prob[text_begin:text_begin+len(sim_preds0[j].split('卍')[0]+'卍')]).data[0] >\
                #  torch.mean(preds1_prob[text_begin:text_begin+len(sim_preds1[j].split('卍')[0]+'卍')]).data[0]:
                #     sim_preds.append(sim_preds0[j].split('卍')[0]+'卍')
                # else:
                #     sim_preds.append(sim_preds1[j].split('卍')[0][-1::-1]+'卍')

                if torch.mean(preds0_prob[text_begin:text_begin+len(sim_preds0[j].split('卍')[0]+'卍')]).data >\
                   torch.mean(preds1_prob[text_begin:text_begin+len(sim_preds1[j].split('卍')[0]+'卍')]).data:
                    sim_preds.append(sim_preds0[j].split('卍')[0]+'卍')
                else:
                    sim_preds.append(sim_preds1[j].split('卍')[0][-1::-1]+'卍')
                # print(sim_preds1[j].split('卍')[0][-1::-1]+'卍')
        else:
            cpu_images, cpu_texts = data
            utils.loadData(image, cpu_images)
            t, l = converter.encode(cpu_texts, scanned=True)
            utils.loadData(text, t)
            utils.loadData(length, l)
            preds = MORAN(image, length, text, text_rev, test=True)
            cost = criterion(preds, text)
            _, preds = preds.max(1)
            preds = preds.view(-1)
            sim_preds = converter.decode(preds.data, length.data)

        loss_avg.add(cost)
        for pred, target in zip(sim_preds, cpu_texts):
            if pred == target.lower():
                n_correct += 1
            n_total += 1

    print("correct / total: %d / %d, " % (n_correct, n_total))
    accuracy = n_correct / float(n_total)
    print('Test loss: %f, accuracy: %f' % (loss_avg.val(), accuracy))
    return accuracy

def trainBatch():
    data = train_iter.next()
    if opt.BidirDecoder:
        cpu_images, cpu_texts, cpu_texts_rev = data
        utils.loadData(image, cpu_images)
        t, l = converter.encode(cpu_texts, scanned=True)
        t_rev, _ = converter.encode(cpu_texts_rev, scanned=True)
        utils.loadData(text, t)
        utils.loadData(text_rev, t_rev)
        utils.loadData(length, l)
        preds0, preds1 = MORAN(image, length, text, text_rev)
        cost = criterion(torch.cat([preds0, preds1], 0), torch.cat([text, text_rev], 0))
    else:
        cpu_images, cpu_texts = data
        utils.loadData(image, cpu_images)
        t, l = converter.encode(cpu_texts, scanned=True)
        utils.loadData(text, t)
        utils.loadData(length, l)
        preds = MORAN(image, length, text, text_rev)
        cost = criterion(preds, text)

    MORAN.zero_grad()
    cost.backward()
    optimizer.step()
    return cost

t0 = time.time()
acc = 0
acc_tmp = 0
for epoch in range(opt.niter):

    train_iter = iter(train_loader)
    i = 0
    while i < len(train_loader):

        if i % opt.valInterval == 0 and i != 0:
            for p in MORAN.parameters():
                p.requires_grad = False
            MORAN.eval()

            acc_tmp = val(test_dataset, criterion)
            if acc_tmp > acc:
                acc = acc_tmp
                torch.save(MORAN.state_dict(), '{0}/{1}_{2}.pth'.format(
                    opt.experiment, i, str(acc)[:6]))

        if i % opt.saveInterval == 0:
            torch.save(MORAN.state_dict(), '{0}/{1}_{2}.pth'.format(
                opt.experiment, epoch, i))

        for p in MORAN.parameters():
            p.requires_grad = True
        MORAN.train()

        cost = trainBatch()
        loss_avg.add(cost)

        if i % opt.displayInterval == 0:
            t1 = time.time()
            print('Epoch: %d/%d; iter: %d/%d; Loss: %f; time: %.2f s;' %
                  (epoch, opt.niter, i, len(train_loader), loss_avg.val(), t1 - t0))
            loss_avg.reset()
            t0 = time.time()

        i += 1
@zhengjiawen

Hello, may I ask how the alphabet should be set when training on Chinese?

@kUhNCwlVbsWXClUR
Author

My alphabet here is simply the 5990 Chinese characters concatenated into one string, something of the form alphabet="一二三四五..."

Hello, may I ask how the alphabet should be set when training on Chinese?

#91 (comment)

@zhengjiawen

My alphabet here is simply the 5990 Chinese characters concatenated into one string, something of the form alphabet="一二三四五..."

Hello, may I ask how the alphabet should be set when training on Chinese?

#91 (comment)

Thanks for the reply. One more question: is the trailing $ symbol required? I see the default is '0:1:2:3:4:5:6:7:8:9:a:b:c:d:e:f:g:h:i:j:k:l:m:n:o:p:q:r:s:t:u:v:w:x:y:z:$'. When using the 5990-character Chinese set you mentioned, should the ':' used as a separator be dropped?

@kUhNCwlVbsWXClUR
Author

import pickle
def load(filename='5990.pickle'):
    with open(filename, 'rb') as handle:
        index2char = pickle.load(handle)
    return index2char

alphabet = list(load('5990/5990.pickle').values())
opt.alphabet = '♫'.join(alphabet)  # .replace('ⅰ','Ⅰ').replace('ⅱ','Ⅱ').replace('ⅲ','Ⅲ').replace('ⅳ','Ⅳ').replace('ⅴ','Ⅴ').replace('ⅵ','Ⅵ').replace('ⅶ',
opt.sep = '♫'

This is how I wrote it. '5990.pickle' is a dictionary: {0: '卍', 1: ',', 2: '的', 3: '。', 4: '一', 5: '是', 6: '0', 7: '不', 8: '在', 9: '有', 10: '、', ...}
My understanding is that the $ here is a stop symbol. I replaced it with 卍 (the one used in https://github.com/YCG09/chinese_ocr), because $ itself appears in my training data.
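
To make the role of the stop symbol concrete, here is a minimal sketch (my own illustration, assuming 卍 is used as the stop character as described above; trim_prediction is a hypothetical helper) of how a raw decoded string is cut down to a variable-length prediction, mirroring the split('卍') logic in val():

STOP = '卍'  # stop symbol used instead of '$' in this setup

def trim_prediction(raw_pred):
    # Keep only the characters before the first stop symbol, as val() does
    # with sim_preds0[j].split('卍')[0].
    return raw_pred.split(STOP)[0]

print(trim_prediction('今天天气很好卍卍卍卍'))  # -> 今天天气很好 (6 characters, not a fixed 10)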

@zhengjiawen

import pickle
def load(filename='5990.pickle'):
    with open(filename, 'rb') as handle:
        index2char = pickle.load(handle)
    return index2char

alphabet = list(load('5990/5990.pickle').values())
opt.alphabet = '♫'.join(alphabet)  # .replace('ⅰ','Ⅰ').replace('ⅱ','Ⅱ').replace('ⅲ','Ⅲ').replace('ⅳ','Ⅳ').replace('ⅴ','Ⅴ').replace('ⅵ','Ⅵ').replace('ⅶ',
opt.sep = '♫'

This is how I wrote it. '5990.pickle' is a dictionary: {0: '卍', 1: ',', 2: '的', 3: '。', 4: '一', 5: '是', 6: '0', 7: '不', 8: '在', 9: '有', 10: '、', ...}
My understanding is that the $ here is a stop symbol. I replaced it with 卍 (the one used in https://github.com/YCG09/chinese_ocr), because $ itself appears in my training data.

Thank you so much!
