Add code example for document parsing
Niels Rogge authored and committed Aug 11, 2022
1 parent 6dae1ba commit 65ab2e1
Showing 1 changed file with 43 additions and 2 deletions.
45 changes: 43 additions & 2 deletions docs/source/en/model_doc/donut.mdx
@@ -92,9 +92,50 @@ into a single instance to both extract the input features and decode the predict
{'class': 'advertisement'}
```

The code is exactly the same for document parsing, except that the task prompt is different (e.g. "<s_cord-v2>").
- Step-by-step Document Parsing

Another example can be found below:
```py
>>> import re

>>> from transformers import DonutProcessor, VisionEncoderDecoderModel
>>> from datasets import load_dataset
>>> import torch

>>> processor = DonutProcessor.from_pretrained("nielsr/donut-base-finetuned-cord-v2")
>>> model = VisionEncoderDecoderModel.from_pretrained("nielsr/donut-base-finetuned-cord-v2")

>>> device = "cuda" if torch.cuda.is_available() else "cpu"
>>> model.to(device) # doctest: +IGNORE_RESULT

>>> # load document image
>>> dataset = load_dataset("hf-internal-testing/example-documents", split="test")
>>> image = dataset[2]["image"]

>>> # prepare decoder inputs
>>> task_prompt = "<s_cord-v2>"
>>> decoder_input_ids = processor.tokenizer(task_prompt, add_special_tokens=False, return_tensors="pt").input_ids

>>> pixel_values = processor(image, return_tensors="pt").pixel_values

>>> # autoregressively generate the output sequence
>>> outputs = model.generate(
... pixel_values.to(device),
... decoder_input_ids=decoder_input_ids.to(device),
... max_length=model.decoder.config.max_position_embeddings,
... early_stopping=True,
... pad_token_id=processor.tokenizer.pad_token_id,
... eos_token_id=processor.tokenizer.eos_token_id,
... use_cache=True,
... num_beams=1,
... bad_words_ids=[[processor.tokenizer.unk_token_id]],
... return_dict_in_generate=True,
... )

>>> # decode the generated ids, strip special tokens, then convert to JSON
>>> sequence = processor.batch_decode(outputs.sequences)[0]
>>> sequence = sequence.replace(processor.tokenizer.eos_token, "").replace(processor.tokenizer.pad_token, "")
>>> sequence = re.sub(r"<.*?>", "", sequence, count=1).strip() # remove first task start token
>>> print(processor.token2json(sequence))
{'menu': {'nm': 'CINNAMON SUGAR', 'unitprice': '17,000', 'cnt': '1 x', 'price': '17,000'}, 'sub_total': {'subtotal_price': '17,000'}, 'total': {'total_price': '17,000', 'cashprice': '20,000', 'changeprice': '3,000'}}
```
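Since only the task prompt is task-specific, the generation and post-processing steps above can be factored into a reusable helper. Below is a minimal sketch of that idea, built directly from the example above; the `run_donut` name is ours, not part of the library:

```py
import re


def run_donut(processor, model, image, task_prompt, device="cpu"):
    """Run one Donut task; only `task_prompt` is task-specific."""
    decoder_input_ids = processor.tokenizer(
        task_prompt, add_special_tokens=False, return_tensors="pt"
    ).input_ids
    pixel_values = processor(image, return_tensors="pt").pixel_values
    outputs = model.generate(
        pixel_values.to(device),
        decoder_input_ids=decoder_input_ids.to(device),
        max_length=model.decoder.config.max_position_embeddings,
        early_stopping=True,
        pad_token_id=processor.tokenizer.pad_token_id,
        eos_token_id=processor.tokenizer.eos_token_id,
        use_cache=True,
        num_beams=1,
        bad_words_ids=[[processor.tokenizer.unk_token_id]],
        return_dict_in_generate=True,
    )
    # strip special tokens and the first task start token, then convert to JSON
    sequence = processor.batch_decode(outputs.sequences)[0]
    sequence = sequence.replace(processor.tokenizer.eos_token, "")
    sequence = sequence.replace(processor.tokenizer.pad_token, "")
    sequence = re.sub(r"<.*?>", "", sequence, count=1).strip()
    return processor.token2json(sequence)
```

Document parsing then becomes `run_donut(processor, model, image, "<s_cord-v2>", device)`.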

- Step-by-step Document Visual Question Answering (DocVQA)
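The DocVQA flow follows the same pattern, with the question embedded in the task prompt. A minimal sketch, assuming the `run_donut` helper above; the checkpoint name below is an assumption and may differ from the one used in the docs:

```py
# sketch only: assumes the hypothetical run_donut helper and the `image`
# loaded above; the DocVQA checkpoint name is an assumption
from transformers import DonutProcessor, VisionEncoderDecoderModel

processor = DonutProcessor.from_pretrained("nielsr/donut-base-finetuned-docvqa")
model = VisionEncoderDecoderModel.from_pretrained("nielsr/donut-base-finetuned-docvqa")

# the question is embedded directly in the task prompt
question = "What is the total price?"
task_prompt = f"<s_docvqa><s_question>{question}</s_question><s_answer>"
print(run_donut(processor, model, image, task_prompt))
```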
