Skip to content

Commit

Permalink
Merge pull request #105 from VikParuchuri/dev
Browse files Browse the repository at this point in the history
Speed up OCR 2x
  • Loading branch information
VikParuchuri committed May 18, 2024
2 parents 7a65c45 + fcfefa5 commit 74e8c0c
Show file tree
Hide file tree
Showing 7 changed files with 27 additions and 55 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "surya-ocr"
version = "0.4.7"
description = "OCR, layout, reading order, and line detection in 90+ languages"
authors = ["Vik Paruchuri <vik.paruchuri@gmail.com>"]
readme = "README.md"
Expand Down
5 changes: 3 additions & 2 deletions surya/detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,12 +123,13 @@ def parallel_get_lines(preds, orig_sizes):
def batch_text_detection(images: List, model, processor, batch_size=None) -> List[TextDetectionResult]:
    """Detect text lines in a batch of images.

    Args:
        images: List of PIL images to run detection on.
        model: The text-detection model.
        processor: The processor matching ``model``.
        batch_size: Optional override for the inference batch size; when
            ``None`` the underlying ``batch_detection`` picks a default.

    Returns:
        One ``TextDetectionResult`` per input image, in input order.
    """
    preds, orig_sizes = batch_detection(images, model, processor, batch_size=batch_size)
    results = []
    # Don't parallelize with streamlit, or when there are so few images that
    # process-pool startup overhead would outweigh the benefit.
    if settings.IN_STREAMLIT or len(images) < settings.DETECTOR_MIN_PARALLEL_THRESH:
        for pred, orig_size in zip(preds, orig_sizes):
            results.append(parallel_get_lines(pred, orig_size))
    else:
        # Never spawn more workers than there are images to process.
        max_workers = min(settings.DETECTOR_POSTPROCESSING_CPU_WORKERS, len(images))
        with ProcessPoolExecutor(max_workers=max_workers) as executor:
            results = list(executor.map(parallel_get_lines, preds, orig_sizes))

    return results
Expand Down
34 changes: 12 additions & 22 deletions surya/input/processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import random
from typing import List

import cv2
import numpy as np
import math
import pypdfium2
Expand Down Expand Up @@ -83,35 +84,24 @@ def slice_polys_from_image(image: Image.Image, polys):
image_array = np.array(image)
lines = []
for idx, poly in enumerate(polys):
lines.append(slice_and_pad_poly(image, image_array, poly, idx))
lines.append(slice_and_pad_poly(image_array, poly))
return lines


def slice_and_pad_poly(image_array: np.array, coordinates):
    """Crop a polygon's bounding box from an image array, padding every
    pixel outside the polygon with the recognition pad value.

    Args:
        image_array: HxWxC (typically HxWx3 uint8) numpy image array.
        coordinates: Sequence of (x, y) polygon corner points.
            NOTE(review): slicing below assumes integer pixel coordinates —
            float corners would raise on the array slice; confirm upstream.

    Returns:
        PIL image of the polygon's axis-aligned bounding box, with the area
        outside the polygon set to ``settings.RECOGNITION_PAD_VALUE``.
    """
    # Axis-aligned bounding box: (x_min, y_min, x_max, y_max).
    xs = [corner[0] for corner in coordinates]
    ys = [corner[1] for corner in coordinates]
    bbox = [min(xs), min(ys), max(xs), max(ys)]

    # Crop the bbox (copy so the pad write below can't touch the source),
    # then shift the polygon corners into the crop's coordinate frame.
    cropped_polygon = image_array[bbox[1]:bbox[3], bbox[0]:bbox[2]].copy()
    coordinates = [(x - bbox[0], y - bbox[1]) for x, y in coordinates]

    # Rasterize the polygon into a mask and pad everything outside it.
    mask = np.zeros_like(cropped_polygon, dtype=np.uint8)
    cv2.fillPoly(mask, [np.int32(coordinates)], 1)
    cropped_polygon[mask == 0] = settings.RECOGNITION_PAD_VALUE

    return Image.fromarray(cropped_polygon)
5 changes: 3 additions & 2 deletions surya/layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,13 +186,14 @@ def batch_layout_detection(images: List, model, processor, detection_results: Op
id2label = model.config.id2label

results = []
if settings.IN_STREAMLIT: # Ensures we don't parallelize with streamlit
if settings.IN_STREAMLIT or len(images) < settings.DETECTOR_MIN_PARALLEL_THRESH: # Ensures we don't parallelize with streamlit or too few images
for i in range(len(images)):
result = parallel_get_regions(preds[i], orig_sizes[i], id2label, detection_results[i] if detection_results else None)
results.append(result)
else:
futures = []
with ProcessPoolExecutor(max_workers=settings.DETECTOR_POSTPROCESSING_CPU_WORKERS) as executor:
max_workers = min(settings.DETECTOR_POSTPROCESSING_CPU_WORKERS, len(images))
with ProcessPoolExecutor(max_workers=max_workers) as executor:
for i in range(len(images)):
future = executor.submit(parallel_get_regions, preds[i], orig_sizes[i], id2label, detection_results[i] if detection_results else None)
futures.append(future)
Expand Down
32 changes: 6 additions & 26 deletions surya/ocr.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,38 +62,18 @@ def run_recognition(images: List[Image.Image], langs: List[List[str]], rec_model
return predictions_by_image


def parallel_slice_polys(det_pred, image):
polygons = [p.polygon for p in det_pred.bboxes]
slices = slice_polys_from_image(image, polygons)
return slices


def run_ocr(images: List[Image.Image], langs: List[List[str]], det_model, det_processor, rec_model, rec_processor, batch_size=None) -> List[OCRResult]:
det_predictions = batch_text_detection(images, det_model, det_processor)
if det_model.device.type == "cuda":
torch.cuda.empty_cache() # Empty cache from first model run

all_slices = []

if settings.IN_STREAMLIT:
all_slices = [parallel_slice_polys(det_pred, image) for det_pred, image in zip(det_predictions, images)]
else:
futures = []
with ProcessPoolExecutor(max_workers=settings.DETECTOR_POSTPROCESSING_CPU_WORKERS) as executor:
for image_idx in range(len(images)):
future = executor.submit(parallel_slice_polys, det_predictions[image_idx], images[image_idx])
futures.append(future)

for future in futures:
all_slices.append(future.result())

slice_map = []
all_langs = []
for idx, (slice, lang) in enumerate(zip(all_slices, langs)):
slice_map.append(len(slice))
all_langs.extend([lang] * len(slice))

all_slices = [slice for sublist in all_slices for slice in sublist]
for idx, (det_pred, image, lang) in enumerate(zip(det_predictions, images, langs)):
polygons = [p.polygon for p in det_pred.bboxes]
slices = slice_polys_from_image(image, polygons)
slice_map.append(len(slices))
all_langs.extend([lang] * len(slices))
all_slices.extend(slices)

rec_predictions, confidence_scores = batch_recognition(all_slices, all_langs, rec_model, rec_processor, batch_size=batch_size)

Expand Down
3 changes: 1 addition & 2 deletions surya/recognition.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,14 @@ def batch_recognition(images: List, languages: List[List[str]], model, processor
if batch_size is None:
batch_size = get_batch_size()

images = [image.convert("RGB") for image in images]

output_text = []
confidences = []

for i in tqdm(range(0, len(images), batch_size), desc="Recognizing Text"):
batch_langs = languages[i:i+batch_size]
has_math = ["_math" in lang for lang in batch_langs]
batch_images = images[i:i+batch_size]
batch_images = [image.convert("RGB") for image in batch_images]
model_inputs = processor(text=[""] * len(batch_langs), images=batch_images, lang=batch_langs)

batch_pixel_values = model_inputs["pixel_values"]
Expand Down
1 change: 1 addition & 0 deletions surya/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def TORCH_DEVICE_DETECTION(self) -> str:
DETECTOR_TEXT_THRESHOLD: float = 0.6 # Threshold for text detection (above this is considered text)
DETECTOR_BLANK_THRESHOLD: float = 0.35 # Threshold for blank space (below this is considered blank)
DETECTOR_POSTPROCESSING_CPU_WORKERS: int = min(8, os.cpu_count()) # Number of workers for postprocessing
DETECTOR_MIN_PARALLEL_THRESH: int = 3 # Minimum number of images before we parallelize

# Text recognition
RECOGNITION_MODEL_CHECKPOINT: str = "vikp/surya_rec"
Expand Down

0 comments on commit 74e8c0c

Please sign in to comment.