Skip to content

Commit

Permalink
Fix graph break
Browse files Browse the repository at this point in the history
  • Loading branch information
VikParuchuri committed May 27, 2024
1 parent 330b595 commit c1dd9c3
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 4 deletions.
9 changes: 9 additions & 0 deletions benchmark/recognition.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,12 @@ def main():
parser.add_argument("--tesseract", action="store_true", help="Run tesseract instead of surya.", default=False)
parser.add_argument("--langs", type=str, help="Specify certain languages to benchmark.", default=None)
parser.add_argument("--tess_cpus", type=int, help="Number of CPUs to use for tesseract.", default=28)
parser.add_argument("--compile", action="store_true", help="Compile the model.", default=False)
args = parser.parse_args()

if args.compile:
assert settings.RECOGNITION_STATIC_CACHE, "You must set RECOGNITION_STATIC_CACHE to compile the model."

rec_model = load_recognition_model()
rec_processor = load_recognition_processor()

Expand Down Expand Up @@ -58,6 +62,11 @@ def main():
else:
lang_list.append(l)

if args.compile:
rec_model.decoder.model.decoder = torch.compile(rec_model.decoder.model.decoder)
# Run through one batch to compile the model
run_recognition(images[:1], lang_list[:1], rec_model, rec_processor, bboxes=bboxes[:1])

start = time.time()
predictions_by_image = run_recognition(images, lang_list, rec_model, rec_processor, bboxes=bboxes)
surya_time = time.time() - start
Expand Down
7 changes: 3 additions & 4 deletions surya/model/recognition/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,9 +84,8 @@ def forward(self, hidden_states: torch.Tensor, langs: torch.LongTensor) -> torch
# Set weights to 1 if zero experts activated
routing_weights[torch.isinf(routing_weights)] = 1

unique_langs = langs.unique(dim=None).tolist()
unique_langs = [l for l in unique_langs if l in self.lang_codes]
unique_langs = sorted(unique_langs)
unique_langs = langs.unique(dim=None, sorted=True)
unique_langs = unique_langs[unique_langs > 3] # Remove start token

# Loop over all available experts in the model and perform the computation on each expert
for expert_lang in unique_langs:
Expand All @@ -97,7 +96,7 @@ def forward(self, hidden_states: torch.Tensor, langs: torch.LongTensor) -> torch
if idx.shape[0] == 0:
continue

expert_layer = self.experts[str(expert_lang)]
expert_layer = self.experts[str(expert_lang.item())]

current_state = hidden_states[idx]
current_hidden_states = expert_layer(current_state.view(-1, hidden_dim))
Expand Down

0 comments on commit c1dd9c3

Please sign in to comment.