diff --git a/pyproject.toml b/pyproject.toml index d800609..8c76bf9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "surya-ocr" -version = "0.4.6" +version = "0.4.7" description = "OCR, layout, reading order, and line detection in 90+ languages" authors = ["Vik Paruchuri "] readme = "README.md" diff --git a/surya/detection.py b/surya/detection.py index 152c1e6..90c3fcb 100644 --- a/surya/detection.py +++ b/surya/detection.py @@ -123,12 +123,13 @@ def parallel_get_lines(preds, orig_sizes): def batch_text_detection(images: List, model, processor, batch_size=None) -> List[TextDetectionResult]: preds, orig_sizes = batch_detection(images, model, processor, batch_size=batch_size) results = [] - if settings.IN_STREAMLIT: # Ensures we don't parallelize with streamlit + if settings.IN_STREAMLIT or len(images) < settings.DETECTOR_MIN_PARALLEL_THRESH: # Ensures we don't parallelize with streamlit, or with very few images for i in range(len(images)): result = parallel_get_lines(preds[i], orig_sizes[i]) results.append(result) else: - with ProcessPoolExecutor(max_workers=settings.DETECTOR_POSTPROCESSING_CPU_WORKERS) as executor: + max_workers = min(settings.DETECTOR_POSTPROCESSING_CPU_WORKERS, len(images)) + with ProcessPoolExecutor(max_workers=max_workers) as executor: results = list(executor.map(parallel_get_lines, preds, orig_sizes)) return results diff --git a/surya/input/processing.py b/surya/input/processing.py index d857a61..feb6021 100644 --- a/surya/input/processing.py +++ b/surya/input/processing.py @@ -2,6 +2,7 @@ import random from typing import List +import cv2 import numpy as np import math import pypdfium2 @@ -83,35 +84,24 @@ def slice_polys_from_image(image: Image.Image, polys): image_array = np.array(image) lines = [] for idx, poly in enumerate(polys): - lines.append(slice_and_pad_poly(image, image_array, poly, idx)) + lines.append(slice_and_pad_poly(image_array, poly)) return lines -def slice_and_pad_poly(image: Image.Image, image_array: np.array, coordinates, idx): - # Create a mask for the polygon - mask = Image.new('L', image.size, 0) - +def slice_and_pad_poly(image_array: np.array, coordinates): # Draw polygon onto mask coordinates = [(corner[0], corner[1]) for corner in coordinates] - ImageDraw.Draw(mask).polygon(coordinates, outline=1, fill=1) - bbox = mask.getbbox() - - if bbox is None: - return None - - mask = np.array(mask) + bbox = [min([x[0] for x in coordinates]), min([x[1] for x in coordinates]), max([x[0] for x in coordinates]), max([x[1] for x in coordinates])] # We mask out anything not in the polygon - polygon_image = image_array.copy() - polygon_image[mask == 0] = settings.RECOGNITION_PAD_VALUE - - # Crop out the bbox, and ensure we pad the area outside the polygon with the pad value - cropped_polygon = polygon_image[bbox[1]:bbox[3], bbox[0]:bbox[2]] - rectangle = np.full((bbox[3] - bbox[1], bbox[2] - bbox[0], 3), settings.RECOGNITION_PAD_VALUE, dtype=np.uint8) - rectangle[:, :] = cropped_polygon + cropped_polygon = image_array[bbox[1]:bbox[3], bbox[0]:bbox[2]].copy() + coordinates = [(x - bbox[0], y - bbox[1]) for x, y in coordinates] - # Paste the polygon into the rectangle - rectangle_image = Image.fromarray(rectangle) + # Pad the area outside the polygon with the pad value + mask = np.zeros_like(cropped_polygon, dtype=np.uint8) + cv2.fillPoly(mask, [np.int32(coordinates)], 1) - return rectangle_image + cropped_polygon[mask == 0] = settings.RECOGNITION_PAD_VALUE + rectangle_image = Image.fromarray(cropped_polygon) + return rectangle_image \ No newline at end of file diff --git a/surya/layout.py b/surya/layout.py index fdb6ef4..68ed61b 100644 --- a/surya/layout.py +++ b/surya/layout.py @@ -186,13 +186,14 @@ def batch_layout_detection(images: List, model, processor, detection_results: Op id2label = model.config.id2label results = [] - if settings.IN_STREAMLIT: # Ensures we don't parallelize with streamlit + if settings.IN_STREAMLIT or len(images) < settings.DETECTOR_MIN_PARALLEL_THRESH: # Ensures we don't parallelize with streamlit or too few images for i in range(len(images)): result = parallel_get_regions(preds[i], orig_sizes[i], id2label, detection_results[i] if detection_results else None) results.append(result) else: futures = [] - with ProcessPoolExecutor(max_workers=settings.DETECTOR_POSTPROCESSING_CPU_WORKERS) as executor: + max_workers = min(settings.DETECTOR_POSTPROCESSING_CPU_WORKERS, len(images)) + with ProcessPoolExecutor(max_workers=max_workers) as executor: for i in range(len(images)): future = executor.submit(parallel_get_regions, preds[i], orig_sizes[i], id2label, detection_results[i] if detection_results else None) futures.append(future) diff --git a/surya/ocr.py b/surya/ocr.py index ee853da..3b072bc 100644 --- a/surya/ocr.py +++ b/surya/ocr.py @@ -62,38 +62,18 @@ def run_recognition(images: List[Image.Image], langs: List[List[str]], rec_model return predictions_by_image -def parallel_slice_polys(det_pred, image): - polygons = [p.polygon for p in det_pred.bboxes] - slices = slice_polys_from_image(image, polygons) - return slices - - def run_ocr(images: List[Image.Image], langs: List[List[str]], det_model, det_processor, rec_model, rec_processor, batch_size=None) -> List[OCRResult]: det_predictions = batch_text_detection(images, det_model, det_processor) - if det_model.device.type == "cuda": - torch.cuda.empty_cache() # Empty cache from first model run all_slices = [] - - if settings.IN_STREAMLIT: - all_slices = [parallel_slice_polys(det_pred, image) for det_pred, image in zip(det_predictions, images)] - else: - futures = [] - with ProcessPoolExecutor(max_workers=settings.DETECTOR_POSTPROCESSING_CPU_WORKERS) as executor: - for image_idx in range(len(images)): - future = executor.submit(parallel_slice_polys, det_predictions[image_idx], images[image_idx]) - futures.append(future) - - for future in futures: - all_slices.append(future.result()) - slice_map = [] all_langs = [] - for idx, (slice, lang) in enumerate(zip(all_slices, langs)): - slice_map.append(len(slice)) - all_langs.extend([lang] * len(slice)) - - all_slices = [slice for sublist in all_slices for slice in sublist] + for idx, (det_pred, image, lang) in enumerate(zip(det_predictions, images, langs)): + polygons = [p.polygon for p in det_pred.bboxes] + slices = slice_polys_from_image(image, polygons) + slice_map.append(len(slices)) + all_langs.extend([lang] * len(slices)) + all_slices.extend(slices) rec_predictions, confidence_scores = batch_recognition(all_slices, all_langs, rec_model, rec_processor, batch_size=batch_size) diff --git a/surya/recognition.py b/surya/recognition.py index 8b8be74..57b5ca6 100644 --- a/surya/recognition.py +++ b/surya/recognition.py @@ -28,8 +28,6 @@ def batch_recognition(images: List, languages: List[List[str]], model, processor if batch_size is None: batch_size = get_batch_size() - images = [image.convert("RGB") for image in images] - output_text = [] confidences = [] @@ -37,6 +35,7 @@ def batch_recognition(images: List, languages: List[List[str]], model, processor batch_langs = languages[i:i+batch_size] has_math = ["_math" in lang for lang in batch_langs] batch_images = images[i:i+batch_size] + batch_images = [image.convert("RGB") for image in batch_images] model_inputs = processor(text=[""] * len(batch_langs), images=batch_images, lang=batch_langs) batch_pixel_values = model_inputs["pixel_values"] diff --git a/surya/settings.py b/surya/settings.py index 3d01f06..4ffdac4 100644 --- a/surya/settings.py +++ b/surya/settings.py @@ -56,6 +56,7 @@ def TORCH_DEVICE_DETECTION(self) -> str: DETECTOR_TEXT_THRESHOLD: float = 0.6 # Threshold for text detection (above this is considered text) DETECTOR_BLANK_THRESHOLD: float = 0.35 # Threshold for blank space (below this is considered blank) DETECTOR_POSTPROCESSING_CPU_WORKERS: int = min(8, os.cpu_count()) # Number of workers for postprocessing + DETECTOR_MIN_PARALLEL_THRESH: int = 3 # Minimum number of images before we parallelize # Text recognition RECOGNITION_MODEL_CHECKPOINT: str = "vikp/surya_rec"