Skip to content

Commit

Permalink
Fix up parallel settings
Browse files Browse the repository at this point in the history
  • Loading branch information
VikParuchuri committed May 18, 2024
1 parent 37e21b9 commit fcfefa5
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 4 deletions.
5 changes: 3 additions & 2 deletions surya/detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,12 +123,13 @@ def parallel_get_lines(preds, orig_sizes):
def batch_text_detection(images: List, model, processor, batch_size=None) -> List[TextDetectionResult]:
preds, orig_sizes = batch_detection(images, model, processor, batch_size=batch_size)
results = []
if settings.IN_STREAMLIT: # Ensures we don't parallelize with streamlit
if settings.IN_STREAMLIT or len(images) < settings.DETECTOR_MIN_PARALLEL_THRESH: # Ensures we don't parallelize with streamlit, or with very few images
for i in range(len(images)):
result = parallel_get_lines(preds[i], orig_sizes[i])
results.append(result)
else:
with ProcessPoolExecutor(max_workers=settings.DETECTOR_POSTPROCESSING_CPU_WORKERS) as executor:
max_workers = min(settings.DETECTOR_POSTPROCESSING_CPU_WORKERS, len(images))
with ProcessPoolExecutor(max_workers=max_workers) as executor:
results = list(executor.map(parallel_get_lines, preds, orig_sizes))

return results
Expand Down
5 changes: 3 additions & 2 deletions surya/layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,13 +186,14 @@ def batch_layout_detection(images: List, model, processor, detection_results: Op
id2label = model.config.id2label

results = []
if settings.IN_STREAMLIT: # Ensures we don't parallelize with streamlit
if settings.IN_STREAMLIT or len(images) < settings.DETECTOR_MIN_PARALLEL_THRESH: # Ensures we don't parallelize with streamlit or too few images
for i in range(len(images)):
result = parallel_get_regions(preds[i], orig_sizes[i], id2label, detection_results[i] if detection_results else None)
results.append(result)
else:
futures = []
with ProcessPoolExecutor(max_workers=settings.DETECTOR_POSTPROCESSING_CPU_WORKERS) as executor:
max_workers = min(settings.DETECTOR_POSTPROCESSING_CPU_WORKERS, len(images))
with ProcessPoolExecutor(max_workers=max_workers) as executor:
for i in range(len(images)):
future = executor.submit(parallel_get_regions, preds[i], orig_sizes[i], id2label, detection_results[i] if detection_results else None)
futures.append(future)
Expand Down
1 change: 1 addition & 0 deletions surya/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def TORCH_DEVICE_DETECTION(self) -> str:
DETECTOR_TEXT_THRESHOLD: float = 0.6 # Threshold for text detection (above this is considered text)
DETECTOR_BLANK_THRESHOLD: float = 0.35 # Threshold for blank space (below this is considered blank)
DETECTOR_POSTPROCESSING_CPU_WORKERS: int = min(8, os.cpu_count()) # Number of workers for postprocessing
DETECTOR_MIN_PARALLEL_THRESH: int = 3 # Minimum number of images before we parallelize

# Text recognition
RECOGNITION_MODEL_CHECKPOINT: str = "vikp/surya_rec"
Expand Down

0 comments on commit fcfefa5

Please sign in to comment.