# Patch summary (three files under surya/):
#   * surya/detection.py, batch_text_detection: the serial post-processing loop
#     over parallel_get_lines is now also used when
#     len(images) < settings.DETECTOR_MIN_PARALLEL_THRESH (previously only when
#     settings.IN_STREAMLIT was set). In the parallel branch the pool size
#     becomes min(settings.DETECTOR_POSTPROCESSING_CPU_WORKERS, len(images)).
#   * surya/layout.py, batch_layout_detection: the same two changes, applied to
#     the branch that calls parallel_get_regions.
#   * surya/settings.py: adds DETECTOR_MIN_PARALLEL_THRESH: int = 3.
# NOTE(review): the line breaks of this patch have been collapsed; the hunks
#   below cannot be applied until the original newlines and indentation are
#   restored from the source commit.
# NOTE(review): if DETECTOR_MIN_PARALLEL_THRESH is configured as 0 and images is
#   empty, max_workers evaluates to 0, which ProcessPoolExecutor rejects with
#   ValueError -- consider max(1, ...) or an early return; confirm with callers.
# NOTE(review): the unchanged context line min(8, os.cpu_count()) would raise
#   TypeError if os.cpu_count() returns None -- outside this patch's changes,
#   worth confirming separately.
diff --git a/surya/detection.py b/surya/detection.py index 152c1e6..90c3fcb 100644 --- a/surya/detection.py +++ b/surya/detection.py @@ -123,12 +123,13 @@ def parallel_get_lines(preds, orig_sizes): def batch_text_detection(images: List, model, processor, batch_size=None) -> List[TextDetectionResult]: preds, orig_sizes = batch_detection(images, model, processor, batch_size=batch_size) results = [] - if settings.IN_STREAMLIT: # Ensures we don't parallelize with streamlit + if settings.IN_STREAMLIT or len(images) < settings.DETECTOR_MIN_PARALLEL_THRESH: # Ensures we don't parallelize with streamlit, or with very few images for i in range(len(images)): result = parallel_get_lines(preds[i], orig_sizes[i]) results.append(result) else: - with ProcessPoolExecutor(max_workers=settings.DETECTOR_POSTPROCESSING_CPU_WORKERS) as executor: + max_workers = min(settings.DETECTOR_POSTPROCESSING_CPU_WORKERS, len(images)) + with ProcessPoolExecutor(max_workers=max_workers) as executor: results = list(executor.map(parallel_get_lines, preds, orig_sizes)) return results diff --git a/surya/layout.py b/surya/layout.py index fdb6ef4..68ed61b 100644 --- a/surya/layout.py +++ b/surya/layout.py @@ -186,13 +186,14 @@ def batch_layout_detection(images: List, model, processor, detection_results: Op id2label = model.config.id2label results = [] - if settings.IN_STREAMLIT: # Ensures we don't parallelize with streamlit + if settings.IN_STREAMLIT or len(images) < settings.DETECTOR_MIN_PARALLEL_THRESH: # Ensures we don't parallelize with streamlit or too few images for i in range(len(images)): result = parallel_get_regions(preds[i], orig_sizes[i], id2label, detection_results[i] if detection_results else None) results.append(result) else: futures = [] - with ProcessPoolExecutor(max_workers=settings.DETECTOR_POSTPROCESSING_CPU_WORKERS) as executor: + max_workers = min(settings.DETECTOR_POSTPROCESSING_CPU_WORKERS, len(images)) + with ProcessPoolExecutor(max_workers=max_workers) as executor: for i in 
range(len(images)): future = executor.submit(parallel_get_regions, preds[i], orig_sizes[i], id2label, detection_results[i] if detection_results else None) futures.append(future) diff --git a/surya/settings.py b/surya/settings.py index 3d01f06..4ffdac4 100644 --- a/surya/settings.py +++ b/surya/settings.py @@ -56,6 +56,7 @@ def TORCH_DEVICE_DETECTION(self) -> str: DETECTOR_TEXT_THRESHOLD: float = 0.6 # Threshold for text detection (above this is considered text) DETECTOR_BLANK_THRESHOLD: float = 0.35 # Threshold for blank space (below this is considered blank) DETECTOR_POSTPROCESSING_CPU_WORKERS: int = min(8, os.cpu_count()) # Number of workers for postprocessing + DETECTOR_MIN_PARALLEL_THRESH: int = 3 # Minimum number of images before we parallelize # Text recognition RECOGNITION_MODEL_CHECKPOINT: str = "vikp/surya_rec"