Skip to content

Commit

Permalink
[Wav2Vec2] Fix convert (huggingface#11562)
Browse files Browse the repository at this point in the history
* push

* small change

* correct other typo
  • Loading branch information
patrickvonplaten authored May 3, 2021
1 parent 623281a commit c448c01
Showing 1 changed file with 5 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -178,9 +178,11 @@ def convert_wav2vec2_checkpoint(
if dict_path:
target_dict = Dictionary.load(dict_path)

config.bos_token_id = target_dict.bos_index
# important change bos & pad token id since CTC symbol is <pad> and
# not <s> as in fairseq
config.bos_token_id = target_dict.pad_index
config.pad_token_id = target_dict.bos_index
config.eos_token_id = target_dict.eos_index
config.pad_token_id = target_dict.pad_index
config.vocab_size = len(target_dict.symbols)
vocab_path = os.path.join(pytorch_dump_folder_path, "vocab.json")
if not os.path.isdir(pytorch_dump_folder_path):
Expand Down Expand Up @@ -214,9 +216,8 @@ def convert_wav2vec2_checkpoint(
hf_wav2vec = Wav2Vec2Model(config)

if is_finetuned:

model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task(
[checkpoint_path], arg_overrides={"data": dict_path}
[checkpoint_path], arg_overrides={"data": "/".join(dict_path.split("/")[:-1])}
)
else:
model, _, _ = fairseq.checkpoint_utils.load_model_ensemble_and_task([checkpoint_path])
Expand Down

0 comments on commit c448c01

Please sign in to comment.