From 1eb828a5772fd69810bc5fcb3f1160b19d843c53 Mon Sep 17 00:00:00 2001
From: Vik Paruchuri
Date: Mon, 27 May 2024 14:52:33 -0700
Subject: [PATCH] Pad out languages if needed

---
 surya/recognition.py | 10 +++++++++-
 1 file changed, 9 insertions(+), 1 deletion(-)

diff --git a/surya/recognition.py b/surya/recognition.py
index bdce430..8853fe1 100644
--- a/surya/recognition.py
+++ b/surya/recognition.py
@@ -63,6 +63,14 @@ def batch_recognition(images: List, languages: List[List[str]], model, processor

         batch_pixel_values = processed_batches["pixel_values"][i:i+batch_size]
         batch_langs = processed_batches["langs"][i:i+batch_size]
+        max_lang_len = max([len(lang) for lang in batch_langs])
+
+        # Pad languages to max length if needed, to ensure we can convert to a tensor
+        for lang_idx in range(len(batch_langs)):
+            lang_len = len(batch_langs[lang_idx])
+            if lang_len < max_lang_len:
+                batch_langs[lang_idx] = [processor.tokenizer.pad_id] * (max_lang_len - lang_len) + batch_langs[lang_idx]
+
         batch_decoder_input = [[model.config.decoder_start_token_id] + lang for lang in batch_langs]

         current_batch_size = len(batch_pixel_values)
@@ -120,7 +128,7 @@ def batch_recognition(images: List, languages: List[List[str]], model, processor
         encoder_cache = [None] * layer_count
         all_done = torch.zeros(current_batch_size, dtype=torch.bool, device=model.device)

-        with torch.no_grad():
+        with torch.no_grad(): # inference_mode doesn't work with torch.compile
             # Run post-prefill tokens
             while token_count < settings.RECOGNITION_MAX_TOKENS:
                 is_prefill = token_count == 0
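
Note (illustration, not part of the patch): the padding in the first hunk exists because torch.tensor() rejects ragged nested lists, and each image in a batch can carry a different number of language tokens. Below is a minimal standalone sketch of the same left-padding, where PAD_ID and the literal token values stand in for processor.tokenizer.pad_id and the real language/start token ids:

import torch

PAD_ID = 0  # stands in for processor.tokenizer.pad_id

# Ragged per-image language token lists; torch.tensor(batch_langs) would raise here.
batch_langs = [[5, 7], [5], [5, 7, 9]]

max_lang_len = max(len(lang) for lang in batch_langs)
for lang_idx in range(len(batch_langs)):
    pad = max_lang_len - len(batch_langs[lang_idx])
    # Left-pad so the real language tokens stay adjacent to the decoder input
    # that is built from them next.
    batch_langs[lang_idx] = [PAD_ID] * pad + batch_langs[lang_idx]

batch_decoder_input = [[1] + lang for lang in batch_langs]  # 1 ~ decoder_start_token_id
print(torch.tensor(batch_decoder_input).shape)  # torch.Size([3, 4]): now rectangular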
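
Note (illustration, not part of the patch): the second hunk only documents an existing choice. Per the added comment, torch.inference_mode() can break under torch.compile, so the decode loop stays under torch.no_grad(), which still disables autograd tracking. A minimal sketch of the pattern, assuming PyTorch 2.x for torch.compile:

import torch

model = torch.nn.Linear(4, 4)
compiled = torch.compile(model)  # requires PyTorch 2.x

x = torch.randn(2, 4)
with torch.no_grad():  # chosen over inference_mode, which can error with compiled graphs
    y = compiled(x)
print(y.requires_grad)  # False: no autograd graph was recorded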