
Commit

Soft error whisper. (huggingface#22475)
* Soft error whisper.

* Fix format.

---------

Co-authored-by: Ubuntu <ubuntu@ip-172-31-34-94.taildb5d.ts.net>
Narsil and Ubuntu authored Apr 4, 2023
1 parent 98268b2 commit a515d0a
Showing 2 changed files with 33 additions and 4 deletions.
4 changes: 1 addition & 3 deletions src/transformers/models/whisper/tokenization_whisper.py
@@ -877,9 +877,7 @@ def new_chunk():
 
     if previous_tokens:
         if return_timestamps:
-            # Last token should always be timestamps, so there shouldn't be
-            # leftover
-            raise ValueError(
+            logger.warning(
                 "There was an error while processing timestamps, we haven't found a timestamp as last token. Was"
                 " WhisperTimeStampLogitsProcessor used?"
             )
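For context, here is a minimal sketch of the pattern this change applies; the helper name and surrounding structure below are illustrative, not the actual `_decode_asr` internals. When decoding ends with leftover tokens but no final timestamp token, the tokenizer now logs a warning and still returns the text instead of raising `ValueError`.

```python
import logging

logger = logging.getLogger(__name__)


def flush_leftover_tokens(previous_tokens, return_timestamps, decode, chunks, full_text):
    """Hypothetical helper mirroring the end-of-decode branch touched by this commit."""
    if previous_tokens:
        if return_timestamps:
            # Before this commit this branch raised ValueError; it is now a soft error.
            logger.warning(
                "There was an error while processing timestamps, we haven't found a timestamp as last token. Was"
                " WhisperTimeStampLogitsProcessor used?"
            )
        # The leftover text is still decoded and returned; its timestamps are unknown,
        # which is why the new test expects a chunk with timestamp (None, None).
        text = decode(previous_tokens)
        full_text.append(text)
        chunks.append({"text": text, "timestamp": (None, None)})
```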
33 changes: 32 additions & 1 deletion tests/pipelines/test_pipelines_automatic_speech_recognition.py
@@ -17,7 +17,7 @@
 import numpy as np
 import pytest
 from datasets import load_dataset
-from huggingface_hub import snapshot_download
+from huggingface_hub import hf_hub_download, snapshot_download
 
 from transformers import (
     MODEL_FOR_CTC_MAPPING,
@@ -39,6 +39,7 @@
     require_pyctcdecode,
     require_tf,
     require_torch,
+    require_torch_gpu,
     require_torchaudio,
     slow,
 )
@@ -1158,6 +1159,36 @@ def test_stride(self):
         output = speech_recognizer({"raw": waveform, "stride": (1000, 8000), "sampling_rate": 16_000})
         self.assertEqual(output, {"text": "XB"})
 
+    @slow
+    @require_torch_gpu
+    def test_slow_unfinished_sequence(self):
+        from transformers import GenerationConfig
+
+        pipe = pipeline(
+            "automatic-speech-recognition",
+            model="vasista22/whisper-hindi-large-v2",
+            device="cuda:0",
+        )
+        # Original model wasn't trained with timestamps and has incorrect generation config
+        pipe.model.generation_config = GenerationConfig.from_pretrained("openai/whisper-large-v2")
+
+        audio = hf_hub_download("Narsil/asr_dummy", filename="hindi.ogg", repo_type="dataset")
+
+        out = pipe(
+            audio,
+            return_timestamps=True,
+        )
+        self.assertEqual(
+            out,
+            {
+                "chunks": [
+                    {"text": "", "timestamp": (18.94, 0.0)},
+                    {"text": "मिर्ची में कितने विभिन्न प्रजातियां हैं", "timestamp": (None, None)},
+                ],
+                "text": "मिर्ची में कितने विभिन्न प्रजातियां हैं",
+            },
+        )
+
 
 def require_ffmpeg(test_case):
     """
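For readers who want to see the new behavior outside the unittest harness, the following is a minimal reproduction sketch of what the test above exercises. It assumes a CUDA-capable machine, network access to the Hub, and the same Hindi checkpoint and audio sample used in the test.

```python
from huggingface_hub import hf_hub_download
from transformers import GenerationConfig, pipeline
from transformers.utils import logging as hf_logging

hf_logging.set_verbosity_warning()  # make the new soft error visible in the logs

pipe = pipeline(
    "automatic-speech-recognition",
    model="vasista22/whisper-hindi-large-v2",
    device="cuda:0",
)
# The checkpoint was not trained with timestamps, so borrow a generation config that was.
pipe.model.generation_config = GenerationConfig.from_pretrained("openai/whisper-large-v2")

audio = hf_hub_download("Narsil/asr_dummy", filename="hindi.ogg", repo_type="dataset")

# Before this commit this call raised ValueError; it now logs a warning and returns the text.
out = pipe(audio, return_timestamps=True)
print(out["text"])
```

The test itself is gated by @slow and @require_torch_gpu, so it only runs on a GPU machine with slow tests enabled.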
