Pad out languages if needed
VikParuchuri committed May 27, 2024
1 parent 407c34d commit 1eb828a
Showing 1 changed file with 9 additions and 1 deletion.
surya/recognition.py
@@ -63,6 +63,14 @@ def batch_recognition(images: List, languages: List[List[str]], model, processor

         batch_pixel_values = processed_batches["pixel_values"][i:i+batch_size]
         batch_langs = processed_batches["langs"][i:i+batch_size]
+        max_lang_len = max([len(lang) for lang in batch_langs])
+
+        # Pad languages to max length if needed, to ensure we can convert to a tensor
+        for lang_idx in range(len(batch_langs)):
+            lang_len = len(batch_langs[lang_idx])
+            if lang_len < max_lang_len:
+                batch_langs[lang_idx] = [processor.tokenizer.pad_id] * (max_lang_len - lang_len) + batch_langs[lang_idx]
+
         batch_decoder_input = [[model.config.decoder_start_token_id] + lang for lang in batch_langs]
         current_batch_size = len(batch_pixel_values)

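Why the padding is needed: torch.tensor cannot build a tensor from ragged nested lists, so each per-image language list is left-padded to the longest list in the batch before the decoder start token is prepended. A minimal standalone sketch of that step (pad_id, decoder_start_token_id, and the token ids below are illustrative stand-ins, not the library's real values):

import torch

# Illustrative values only; in the real code pad_id comes from
# processor.tokenizer.pad_id and batch_langs from the processed batch.
pad_id = 0
decoder_start_token_id = 1
batch_langs = [[5], [6, 7], [8]]  # ragged: torch.tensor(batch_langs) would raise ValueError

max_lang_len = max(len(lang) for lang in batch_langs)
for lang_idx in range(len(batch_langs)):
    lang_len = len(batch_langs[lang_idx])
    if lang_len < max_lang_len:
        # Left-pad so every row reaches the same length
        batch_langs[lang_idx] = [pad_id] * (max_lang_len - lang_len) + batch_langs[lang_idx]

# Now every row has max_lang_len + 1 entries and converts cleanly
batch_decoder_input = [[decoder_start_token_id] + lang for lang in batch_langs]
decoder_input = torch.tensor(batch_decoder_input, dtype=torch.long)
print(decoder_input.shape)  # torch.Size([3, 3])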
@@ -120,7 +128,7 @@ def batch_recognition(images: List, languages: List[List[str]], model, processor
         encoder_cache = [None] * layer_count
         all_done = torch.zeros(current_batch_size, dtype=torch.bool, device=model.device)

-        with torch.no_grad():
+        with torch.no_grad(): # inference_mode doesn't work with torch.compile
             # Run post-prefill tokens
             while token_count < settings.RECOGNITION_MAX_TOKENS:
                 is_prefill = token_count == 0
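The second hunk keeps torch.no_grad() and documents why: per the commit's own comment, torch.inference_mode() does not work under torch.compile, while no_grad() only disables gradient tracking and stays compatible. A hedged sketch of that tradeoff, with decode_step as a made-up stand-in for the model's decode call:

import torch

@torch.compile
def decode_step(x):
    # Made-up stand-in for the model's per-token decode step
    return x * 2

x = torch.ones(4)

with torch.no_grad():  # compatible with compiled functions
    y = decode_step(x)

# with torch.inference_mode():   # per the commit comment, inference tensors
#     y = decode_step(x)         # can break under torch.compile, hence no_grad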
