forked from tesseract-ocr/tesseract
-
Notifications
You must be signed in to change notification settings - Fork 0
/
lstmrecognizer.h
394 lines (365 loc) · 18 KB
/
lstmrecognizer.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
///////////////////////////////////////////////////////////////////////
// File: lstmrecognizer.h
// Description: Top-level line recognizer class for LSTM-based networks.
// Author: Ray Smith
// Created: Thu May 02 08:57:06 PST 2013
//
// (C) Copyright 2013, Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
///////////////////////////////////////////////////////////////////////
#ifndef TESSERACT_LSTM_LSTMRECOGNIZER_H_
#define TESSERACT_LSTM_LSTMRECOGNIZER_H_
#include "ccutil.h"
#include "helpers.h"
#include "imagedata.h"
#include "matrix.h"
#include "network.h"
#include "networkscratch.h"
#include "recodebeam.h"
#include "series.h"
#include "strngs.h"
#include "unicharcompress.h"
class BLOB_CHOICE_IT;
struct Pix;
class ROW_RES;
class ScrollView;
class TBOX;
class WERD_RES;
namespace tesseract {
class Dict;
class ImageData;
// Bit flags stored in LSTMRecognizer::training_flags_ that control how the
// network is trained and run. Each enumerator occupies a distinct bit so the
// flags can be OR-ed together and tested individually with `&`.
// Kept as a plain (unscoped) enum because callers combine the values directly
// with an integer flags field.
enum TrainingFlags {
  TF_INT_MODE = 1 << 0,              // == 1
  TF_AUTO_HARDEN = 1 << 1,           // == 2
  TF_ROUND_ROBIN_TRAINING = 1 << 4,  // == 16
  TF_COMPRESS_UNICHARSET = 1 << 6,   // == 64
};
// Top-level line recognizer class for LSTM-based networks.
// Note that a sub-class, LSTMTrainer, is used for training.
//
// NOTE(review): network_, dict_, search_ and debug_win_ are raw pointers, and
// the class declares a destructor but neither defines nor deletes its copy
// constructor / copy assignment. If the destructor releases these pointers,
// copying an LSTMRecognizer would double-free — confirm against the .cpp and
// consider deleting the copy operations.
class LSTMRecognizer {
 public:
  LSTMRecognizer();
  ~LSTMRecognizer();

  // Number of outputs of the network. Requires network_ to be non-null.
  int NumOutputs() const {
    return network_->NumOutputs();
  }
  // Number of actual backward training steps performed so far.
  int training_iteration() const {
    return training_iteration_;
  }
  // Current index into the training sample set.
  int sample_iteration() const {
    return sample_iteration_;
  }
  // The global learning rate. See GetLayerLearningRate for per-layer rates.
  double learning_rate() const {
    return learning_rate_;
  }
  // True if TF_AUTO_HARDEN is set in training_flags_.
  bool IsHardening() const {
    return (training_flags_ & TF_AUTO_HARDEN) != 0;
  }
  // Returns the loss type of the network's output shape, or LT_NONE if there
  // is no network.
  LossType OutputLossType() const {
    if (network_ == nullptr) return LT_NONE;
    StaticShape shape;
    shape = network_->OutputShape(shape);
    return shape.loss_type();
  }
  // True if the output loss type is LT_SOFTMAX.
  bool SimpleTextOutput() const { return OutputLossType() == LT_SOFTMAX; }
  // True if TF_INT_MODE is set in training_flags_.
  bool IsIntMode() const { return (training_flags_ & TF_INT_MODE) != 0; }
  // True if recoder_ is active to re-encode text to a smaller space.
  bool IsRecoding() const {
    return (training_flags_ & TF_COMPRESS_UNICHARSET) != 0;
  }
  // Returns the cache strategy for the DocumentCache: round-robin if
  // TF_ROUND_ROBIN_TRAINING is set, otherwise sequential.
  CachingStrategy CacheStrategy() const {
    // Parenthesized so the bit test is explicit rather than relying on `&`
    // binding tighter than `?:`.
    return (training_flags_ & TF_ROUND_ROBIN_TRAINING) != 0 ? CS_ROUND_ROBIN
                                                            : CS_SEQUENTIAL;
  }
  // Returns true if the network is a TensorFlow network.
  // Requires network_ to be non-null.
  bool IsTensorFlow() const { return network_->type() == NT_TENSORFLOW; }
  // Returns a vector of layer ids that can be passed to other layer functions
  // to access a specific layer. Requires the network to be an NT_SERIES.
  GenericVector<STRING> EnumerateLayers() const {
    ASSERT_HOST(network_ != nullptr && network_->type() == NT_SERIES);
    // static_cast is the correct named cast for a downcast within the class
    // hierarchy (it adjusts the pointer if needed); the type was checked above.
    Series* series = static_cast<Series*>(network_);
    GenericVector<STRING> layers;
    series->EnumerateLayers(nullptr, &layers);
    return layers;
  }
  // Returns a specific layer from its id (from EnumerateLayers).
  // The id must start with ':' and have at least one character after it.
  Network* GetLayer(const STRING& id) const {
    ASSERT_HOST(network_ != nullptr && network_->type() == NT_SERIES);
    ASSERT_HOST(id.length() > 1 && id[0] == ':');
    Series* series = static_cast<Series*>(network_);
    // Skip the leading ':' of the id.
    return series->GetLayer(&id[1]);
  }
  // Returns the learning rate of the layer from its id. If the network does
  // not use layer-specific learning rates (NF_LAYER_SPECIFIC_LR), returns the
  // global learning rate instead.
  float GetLayerLearningRate(const STRING& id) const {
    ASSERT_HOST(network_ != nullptr && network_->type() == NT_SERIES);
    if (network_->TestFlag(NF_LAYER_SPECIFIC_LR)) {
      ASSERT_HOST(id.length() > 1 && id[0] == ':');
      Series* series = static_cast<Series*>(network_);
      return series->LayerLearningRate(&id[1]);
    } else {
      return learning_rate_;
    }
  }
  // Multiplies all the learning rate(s) by the given factor: the global
  // learning_rate_ and, if NF_LAYER_SPECIFIC_LR is set, the rate of every
  // layer returned by EnumerateLayers.
  void ScaleLearningRate(double factor) {
    ASSERT_HOST(network_ != nullptr && network_->type() == NT_SERIES);
    learning_rate_ *= factor;
    if (network_->TestFlag(NF_LAYER_SPECIFIC_LR)) {
      GenericVector<STRING> layers = EnumerateLayers();
      for (int i = 0; i < layers.size(); ++i) {
        ScaleLayerLearningRate(layers[i], factor);
      }
    }
  }
  // Multiplies the learning rate of the layer with id by the given factor.
  void ScaleLayerLearningRate(const STRING& id, double factor) {
    ASSERT_HOST(network_ != nullptr && network_->type() == NT_SERIES);
    ASSERT_HOST(id.length() > 1 && id[0] == ':');
    Series* series = static_cast<Series*>(network_);
    series->ScaleLayerLearningRate(&id[1], factor);
  }
  // True if the network is using adagrad to train (NF_ADA_GRAD flag).
  bool IsUsingAdaGrad() const { return network_->TestFlag(NF_ADA_GRAD); }
  // Provides access to the UNICHARSET that this classifier works with.
  const UNICHARSET& GetUnicharset() const { return ccutil_.unicharset; }
  // Provides access to the Dict that this classifier works with. May be null
  // if no dictionary was loaded.
  const Dict* GetDict() const { return dict_; }
  // Sets the sample iteration to the given value. The sample_iteration_
  // determines the seed for the random number generator. The training
  // iteration is incremented only by a successful training iteration.
  void SetIteration(int iteration) {
    sample_iteration_ = iteration;
  }
  // Accessors for textline image normalization.
  // Number of inputs of the network. Requires network_ to be non-null.
  int NumInputs() const {
    return network_->NumInputs();
  }
  // Index in the softmax output of the null character.
  int null_char() const { return null_char_; }
  // Writes to the given file. Returns false in case of error.
  bool Serialize(TFile* fp) const;
  // Reads from the given file. Returns false in case of error.
  // If swap is true, assumes a big/little-endian swap is needed.
  bool DeSerialize(bool swap, TFile* fp);
  // Loads the dictionary if possible from the traineddata file.
  // Prints a warning message, and returns false but otherwise fails silently
  // and continues to work without it if loading fails.
  // Note that dictionary load is independent from DeSerialize, but dependent
  // on the unicharset matching. This enables training to deserialize a model
  // from checkpoint or restore without having to go back and reload the
  // dictionary.
  bool LoadDictionary(const char* data_file_name, const char* lang);
  // Recognizes the line image, contained within image_data, returning the
  // ratings matrix and matching box_word for each WERD_RES in the output.
  // If invert, tries inverted as well if the normal interpretation doesn't
  // produce a good enough result. If use_alternates, the ratings matrix is
  // filled with segmentation and classifier alternatives that may be searched
  // using the standard beam search, otherwise, just a diagonal and prebuilt
  // best_choice. The line_box is used for computing the box_word in the
  // output words. Score_ratio is used to determine the classifier alternates.
  // If one_word, then a single WERD_RES is formed, regardless of the spaces
  // found during recognition.
  // If not NULL, we attempt to translate the output to target_unicharset, but
  // do not guarantee success, due to mismatches. In that case the output words
  // are marked with our UNICHARSET, not the caller's.
  void RecognizeLine(const ImageData& image_data, bool invert, bool debug,
                     double worst_dict_cert, bool use_alternates,
                     const UNICHARSET* target_unicharset, const TBOX& line_box,
                     float score_ratio, bool one_word,
                     PointerVector<WERD_RES>* words);
  // Builds a set of tesseract-compatible WERD_RESs aligned to line_box,
  // corresponding to the network output in outputs, labels, label_coords.
  // one_word generates a single word output, that may include spaces inside.
  // use_alternates generates alternative BLOB_CHOICEs and segmentation paths,
  // with cut-offs determined by scale_factor.
  // If not NULL, we attempt to translate the output to target_unicharset, but
  // do not guarantee success, due to mismatches. In that case the output words
  // are marked with our UNICHARSET, not the caller's.
  // NOTE(review): label_coords is taken by value (a copy); the signature is
  // kept unchanged here so it still matches the out-of-line definition.
  void WordsFromOutputs(const NetworkIO& outputs,
                        const GenericVector<int>& labels,
                        const GenericVector<int> label_coords,
                        const TBOX& line_box, bool debug, bool use_alternates,
                        bool one_word, float score_ratio, float scale_factor,
                        const UNICHARSET* target_unicharset,
                        PointerVector<WERD_RES>* words);
  // Helper computes min and mean best results in the output.
  void OutputStats(const NetworkIO& outputs,
                   float* min_output, float* mean_output, float* sd);
  // Runs the network on image_data, filling inputs with the inputs actually
  // used by the network and outputs with the forward outputs.
  // Returned in scale_factor is the reduction factor between the image and
  // the output coords, for computing bounding boxes.
  // If invert, tries the inverted image as well. If re_invert is true, the
  // input is inverted back to its original photometric interpretation if
  // inversion is attempted but fails to improve the results. This ensures
  // that outputs contains the correct forward outputs for the best
  // photometric interpretation.
  // NOTE(review): the original comment also described label_threshold
  // ("if positive, used for making the labels, otherwise standard ctc") and
  // returned labels/coords/target boxes that are not part of this signature;
  // confirm label_threshold's role and the meaning of the bool return
  // against the implementation.
  bool RecognizeLine(const ImageData& image_data, bool invert, bool debug,
                     bool re_invert, float label_threshold, float* scale_factor,
                     NetworkIO* inputs, NetworkIO* outputs);
  // Returns a tesseract-compatible WERD_RES from the line recognizer outputs.
  // line_box should be the bounding box of the line image in the main image,
  // outputs the output of the network,
  // [word_start, word_end) the interval over which to convert,
  // score_ratio for choosing alternate classifier choices,
  // use_alternates to control generation of alternative segmentations,
  // labels, label_coords, scale_factor from RecognizeLine above.
  // If target_unicharset is not NULL, attempts to translate the internal
  // unichar_ids to the target_unicharset, but falls back to untranslated ids
  // if the translation should fail.
  WERD_RES* WordFromOutput(const TBOX& line_box, const NetworkIO& outputs,
                           int word_start, int word_end, float score_ratio,
                           float space_certainty, bool debug,
                           bool use_alternates,
                           const UNICHARSET* target_unicharset,
                           const GenericVector<int>& labels,
                           const GenericVector<int>& label_coords,
                           float scale_factor);
  // Sets up a word with the ratings matrix and fake blobs with boxes in the
  // right places.
  WERD_RES* InitializeWord(const TBOX& line_box, int word_start, int word_end,
                           float space_certainty, bool use_alternates,
                           const UNICHARSET* target_unicharset,
                           const GenericVector<int>& labels,
                           const GenericVector<int>& label_coords,
                           float scale_factor);
  // Converts an array of labels to utf-8, whether or not the labels are
  // augmented with character boundaries.
  STRING DecodeLabels(const GenericVector<int>& labels);
  // Displays the forward results in a window with the characters and
  // boundaries as determined by the labels and label_coords.
  void DisplayForward(const NetworkIO& inputs,
                      const GenericVector<int>& labels,
                      const GenericVector<int>& label_coords,
                      const char* window_name,
                      ScrollView** window);

 protected:
  // Sets the random seed from the sample_iteration_, so that a given sample
  // iteration always reproduces the same random sequence.
  void SetRandomSeed() {
    inT64 seed = static_cast<inT64>(sample_iteration_) * 0x10000001;
    randomizer_.set_seed(seed);
    // NOTE(review): presumably discards the first draw after seeding;
    // confirm the intent before removing.
    randomizer_.IntRand();
  }
  // Displays the labels and cuts at the corresponding xcoords.
  // Size of labels should match xcoords.
  void DisplayLSTMOutput(const GenericVector<int>& labels,
                         const GenericVector<int>& xcoords,
                         int height, ScrollView* window);
  // Prints debug output detailing the activation path that is implied by the
  // xcoords.
  void DebugActivationPath(const NetworkIO& outputs,
                           const GenericVector<int>& labels,
                           const GenericVector<int>& xcoords);
  // Prints debug output detailing activations and 2nd choice over a range
  // of positions.
  void DebugActivationRange(const NetworkIO& outputs, const char* label,
                            int best_choice, int x_start, int x_end);
  // Converts the network output to a sequence of labels. Outputs labels, scores
  // and start xcoords of each char, and each null_char_, with an additional
  // final xcoord for the end of the output.
  // The conversion method is determined by internal state.
  void LabelsFromOutputs(const NetworkIO& outputs, float null_thr,
                         GenericVector<int>* labels,
                         GenericVector<int>* xcoords);
  // Converts the network output to a sequence of labels, using a threshold
  // on the null_char_ to determine character boundaries. Outputs labels, scores
  // and start xcoords of each char, and each null_char_, with an additional
  // final xcoord for the end of the output.
  // The label output is the one with the highest score in the interval between
  // null_chars_.
  void LabelsViaThreshold(const NetworkIO& output,
                          float null_threshold,
                          GenericVector<int>* labels,
                          GenericVector<int>* xcoords);
  // Converts the network output to a sequence of labels, with scores and
  // start x-coords of the character labels. Retains the null_char_ character as
  // the end x-coord, where already present, otherwise the start of the next
  // character is the end.
  // The number of labels, scores, and xcoords is always matched, except that
  // there is always an additional xcoord for the last end position.
  void LabelsViaCTC(const NetworkIO& output,
                    GenericVector<int>* labels,
                    GenericVector<int>* xcoords);
  // As LabelsViaCTC except that this function constructs the best path that
  // contains only legal sequences of subcodes for recoder_.
  void LabelsViaReEncode(const NetworkIO& output, GenericVector<int>* labels,
                         GenericVector<int>* xcoords);
  // Converts the network output to a sequence of labels, with scores, using
  // the simple character model (each position is a char, and the null_char_ is
  // mainly intended for tail padding.)
  void LabelsViaSimpleText(const NetworkIO& output,
                           GenericVector<int>* labels,
                           GenericVector<int>* xcoords);
  // Helper returns a BLOB_CHOICE_LIST for the choices in a given x-range.
  // Handles either LSTM labels or direct unichar-ids.
  // Score ratio determines the worst ratio between top choice and remainder.
  // If target_unicharset is not NULL, attempts to translate to the target
  // unicharset, returning NULL on failure.
  BLOB_CHOICE_LIST* GetBlobChoices(int col, int row, bool debug,
                                   const NetworkIO& output,
                                   const UNICHARSET* target_unicharset,
                                   int x_start, int x_end, float score_ratio);
  // Adds to the given iterator the blob choices for the target_unicharset
  // that correspond to the given LSTM unichar_id.
  // Returns false if unicharset translation failed.
  bool AddBlobChoices(int unichar_id, float rating, float certainty, int col,
                      int row, const UNICHARSET* target_unicharset,
                      BLOB_CHOICE_IT* bc_it);
  // Returns a string corresponding to the label starting at start. Sets *end
  // to the next start and if non-null, *decoded to the unichar id.
  const char* DecodeLabel(const GenericVector<int>& labels, int start, int* end,
                          int* decoded);
  // Returns a string corresponding to a given single label id, falling back to
  // a default of ".." for part of a multi-label unichar-id.
  const char* DecodeSingleLabel(int label);

 protected:
  // The network hierarchy.
  Network* network_;
  // The unicharset. Only the unicharset element is serialized.
  // Has to be a CCUtil, so Dict can point to it.
  CCUtil ccutil_;
  // For backward compatibility, recoder_ is serialized iff
  // training_flags_ & TF_COMPRESS_UNICHARSET.
  // Further encode/decode ccutil_.unicharset's ids to simplify the unicharset.
  UnicharCompress recoder_;

  // ==Training parameters that are serialized to provide a record of them.==
  STRING network_str_;
  // Flags used to determine the training method of the network.
  // See enum TrainingFlags above.
  inT32 training_flags_;
  // Number of actual backward training steps used.
  inT32 training_iteration_;
  // Index into training sample set. sample_iteration_ >= training_iteration_.
  inT32 sample_iteration_;
  // Index in softmax of null character. May take the value UNICHAR_BROKEN or
  // ccutil_.unicharset.size().
  inT32 null_char_;
  // Range used for the initial random numbers in the weights.
  float weight_range_;
  // Learning rate and momentum multipliers of deltas in backprop.
  float learning_rate_;
  float momentum_;

  // === NOT SERIALIZED.
  TRand randomizer_;
  NetworkScratch scratch_space_;
  // Language model (optional) to use with the beam search.
  Dict* dict_;
  // Beam search held between uses to optimize memory allocation/use.
  RecodeBeamSearch* search_;

  // == Debugging parameters.==
  // Recognition debug display window.
  ScrollView* debug_win_;
};
} // namespace tesseract.
#endif // TESSERACT_LSTM_LSTMRECOGNIZER_H_