Skip to content

Commit

Permalink
Image transforms library (#18520)
Browse files Browse the repository at this point in the history
* Adapt FE methods to transforms library

* Mixin for saving the image processor

* Base processor skeleton

* BatchFeature for packaging image processor outputs

* Initial image processor for GLPN

* REmove accidental import

* Fixup and docs

* Mixin for saving the image processor

* Fixup and docs

* Import BatchFeature from feature_extraction_utils

* Fixup and docs

* Fixup and docs

* Fixup and docs

* Fixup and docs

* BatchFeature for packaging image processor outputs

* Import BatchFeature from feature_extraction_utils

* Import BatchFeature from feature_extraction_utils

* Fixup and docs

* Fixup and docs

* BatchFeature for packaging image processor outputs

* Import BatchFeature from feature_extraction_utils

* Fixup and docs

* Mixin for saving the image processor

* Fixup and docs

* Add rescale back and remove ImageType

* fix import mistake

* Fix enum var reference

* Can transform and specify image data format

* Remove redundant function

* Update reference

* Data format flag for rescale

* Fix typo

* Fix dimension check

* Fixes to make IP and FE outputs match

* Add tests for transforms

* Add test for utils

* Update some docstrings

* Make sure in channels last before converting to PIL

* Remove default to numpy batching

* Fix up

* Add docstring and model_input_types

* Use feature processor config from hub

* Alias GLPN feature extractor to image processor

* Alias feature extractor mixin

* Add return_numpy=False flag for resize

* Fix up

* Fix up

* Use different frameworks safely

* Safely import PIL

* Call function checking if PIL available

* Only import if vision available

* Address Sylvain PR comments
Co-authored-by: Sylvain.gugger@gmail.com

* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <Sylvain.gugger@gmail.com>

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update src/transformers/image_transforms.py

Co-authored-by: Alara Dirik <8944735+alaradirik@users.noreply.github.com>

* Update src/transformers/models/glpn/feature_extraction_glpn.py

Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>

* Add in docstrings

* Fix TFSwinSelfAttention to have relative position index as non-trainable weight (#18226)

Signed-off-by: Seunghwan Hong <seunghwan@scatterlab.co.kr>

* Refactor `TFSwinLayer` to increase serving compatibility (#18352)

* Refactor `TFSwinLayer` to increase serving compatibility

Signed-off-by: Seunghwan Hong <seunghwan@scatterlab.co.kr>

* Fix missed parameters while refactoring

Signed-off-by: Seunghwan Hong <seunghwan@scatterlab.co.kr>

* Fix window_reverse to calculate batch size

Signed-off-by: Seunghwan Hong <harrydrippin@gmail.com>
Co-Authored-By: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

Co-authored-by: amyeroberts <22614925+amyeroberts@users.noreply.github.com>

* Add TF prefix to TF-Res test class (#18481)

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>

* Remove py.typed (#18485)

* Fix pipeline tests (#18487)

* Fix pipeline tests

* Make sure all pipelines tests run with init changes

* Use new huggingface_hub tools for download models (#18438)

* Draft new cached_file

* Initial draft for config and model

* Small fixes

* Fix first batch of tests

* Look in cache when internet is down

* Fix last tests

* Bad black, not fixing all quality errors

* Make diff less

* Implement change for TF and Flax models

* Add tokenizer and feature extractor

* For compatibility with main

* Add utils to move the cache and auto-do it at first use.

* Quality

* Deal with empty commit shas

* Deal with empty etag

* Address review comments

* Fix `test_dbmdz_english` by updating expected values (#18482)

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>

* Move cache folder to huggingface/hub for consistency with hf_hub (#18492)

* Move cache folder to just huggingface

* Thank you VsCode for this needless import

* Move to hub

* Forgot one

* Update some expected values in `quicktour.mdx` for `resampy 0.3.0` (#18484)

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>

* Forgot one new_ for cache migration

* disable Onnx test for google/long-t5-tglobal-base (#18454)

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>

* Typo reported by Joel Grus on TWTR (#18493)

* Just re-reading the whole doc every couple of months 😬 (#18489)

* Delete valohai.yaml

* NLP => ML

* typo

* website supports https

* datasets

* 60k + modalities

* unrelated link fixing for accelerate

* Ok those links were actually broken

* Fix link

* Make `AutoTokenizer` auto-link

* wording tweak

* add at least one non-nlp task

* `transformers-cli login` => `huggingface-cli login` (#18490)

* zero chance anyone's using that constant no?

* `transformers-cli login` => `huggingface-cli login`

* `transformers-cli repo create` => `huggingface-cli repo create`

* `make style`

* Add seed setting to image classification example (#18519)

* [DX fix] Fixing QA pipeline streaming a dataset. (#18516)

* [DX fix] Fixing QA pipeline streaming a dataset.

QuestionAnsweringArgumentHandler would iterate over the whole dataset
effectively killing all properties of the pipeline.
This restores nice properties when using `Dataset` or `Generator` since
those are meant to be consumed lazily.

* Handling TF better.

* Clean up hub (#18497)

* Clean up utils.hub

* Remove imports

* More fixes

* Last fix

* update fsdp docs (#18521)

* updating fsdp documentation

* typo fix

* Fix compatibility with 1.12 (#17925)

* Fix compatibility with 1.12

* Remove pin from examples requirements

* Update torch scatter version

* Fix compatibility with 1.12

* Remove pin from examples requirements

* Update torch scatter version

* fix torch.onnx.symbolic_opset12 import

* Reject bad version

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>

* Remove debug statement

* Specify en in doc-builder README example (#18526)

Co-authored-by: Ankur Goyal <ankur@impira.com>

* New cache fixes: add safeguard before looking in folders (#18522)

* unpin resampy (#18527)

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>

* ✨ update to use interlibrary links instead of Markdown (#18500)

* Add example of multimodal usage to pipeline tutorial (#18498)

* 📝 add example of multimodal usage to pipeline tutorial

* 🖍 apply feedbacks

* 🖍 apply niels feedback

* [VideoMAE] Add model to doc tests (#18523)

* Add videomae to doc tests

* Add pip install decord

Co-authored-by: Niels Rogge <nielsrogge@Nielss-MacBook-Pro.local>

* Update perf_train_gpu_one.mdx (#18532)

* Update no_trainer.py scripts to include accelerate gradient accumulation wrapper (#18473)

* Added accelerate gradient accumulation wrapper to run_image_classification_no_trainer.py example script

* make fixup changes

* PR comments

* changed input to Acceletor based on PR comment, ran make fixup

* Added comment explaining the sync_gradients statement

* Fixed lr scheduler max steps

* Changed run_clm_no_trainer.py script to use accelerate gradient accum wrapper

* Fixed all scripts except wav2vec2 pretraining to use accelerate gradient accum wrapper

* Added accelerate gradient accum wrapper for wav2vec2_pretraining_no_trainer.py script

* make fixup and lr_scheduler step inserted back into run_qa_beam_search_no_trainer.py

* removed changes to run_wav2vec2_pretraining_no_trainer.py script and fixed using wrong constant in qa_beam_search_no_trainer.py script

* Add Spanish translation of converting_tensorflow_models.mdx (#18512)

* Add file in spanish docs to be translated

* Finish translation to Spanish

* Improve Spanish  wording

* Add suggested changes from review

* Spanish translation of summarization.mdx (#15947) (#18477)

* Add Spanish translation of summarization.mdx

* Apply suggestions from code review

Co-authored-by: Omar U. Espejel <espejelomar@gmail.com>

Co-authored-by: Omar U. Espejel <espejelomar@gmail.com>

* Let's not cast them all (#18471)

* add correct dtypes when checking for params dtype

* forward contrib credits

* Update src/transformers/modeling_utils.py

Co-authored-by: Thomas Wang <24695242+thomasw21@users.noreply.github.com>

* more comments

- added more comments on why we cast only floating point parameters

* Update src/transformers/modeling_utils.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

Co-authored-by: sgugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Thomas Wang <24695242+thomasw21@users.noreply.github.com>

* fix: data2vec-vision Onnx ready-made configuration. (#18427)

* feat: add the data2vec conf that are missing https://huggingface.co/docs/transformers/serialization

* fix: wrong config

* Add mt5 onnx config (#18394)

* update features

* MT5OnnxConfig added with updated with tests and docs

* fix imports

* fix onnc_config_cls for mt5

Co-authored-by: Thomas Chaigneau <thomas.deeptools.ai>

* Minor update of `run_call_with_unpacked_inputs` (#18541)

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>

* BART - Fix attention mask device issue on copied models (#18540)

* attempt to fix attn mask device

* fix bart `_prepare_decoder_attention_mask`

- add correct device
- run `make fix-copies` to propagate the fix

* Adding a new `align_to_words` param to qa pipeline. (#18010)

* Adding a new `align_to_words` param to qa pipeline.

* Update src/transformers/pipelines/question_answering.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Import protection.

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* 📝 update metric with evaluate (#18535)

* Restore _init_weights value in no_init_weights (#18504)

* Recover _init_weights value in no_init_weights

For potential nested use. 
In addition, users might modify private no_init_weights as well.

* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Remove private variable change check

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Clean up comment

* 📝 update documentation build section (#18548)

* `bitsandbytes` - `Linear8bitLt` integration into `transformers` models (#17901)

* first commit

* correct replace function

* add final changes

- works like charm!
- cannot implement tests yet
- tested

* clean up a bit

* add bitsandbytes dependencies

* working version

- added import function
- added bitsandbytes utils file

* small fix

* small fix

- fix import issue

* fix import issues

* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* refactor a bit

- move bitsandbytes utils to utils
- change comments on functions

* reformat docstring

- reformat docstring on init_empty_weights_8bit

* Update src/transformers/__init__.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* revert bad formatting

* change to bitsandbytes

* refactor a bit

- remove init8bit since it is useless

* more refactoring

- fixed init empty weights issue
- added threshold param

* small hack to make it work

* Update src/transformers/modeling_utils.py

* Update src/transformers/modeling_utils.py

* revmoe the small hack

* modify utils file

* make style + refactor a bit

* create correctly device map

* add correct dtype for device map creation

* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* apply suggestions

- remove with torch.grad
- do not rely on Python bool magic!

* add docstring

 - add docstring for new kwargs

* add docstring

- comment `replace_8bit_linear` function
- fix weird formatting

* - added more documentation
- added new utility function for memory footprint tracking
- colab demo to add

* few modifs

- typo doc
- force cast into float16 when load_in_8bit is enabled

* added colab link

* add test architecture + docstring a bit

* refactor a bit testing class

* make style + refactor a bit

* enhance checks

- add more checks
- start writing saving test

* clean up a bit

* male style

* add more details on doc

* add more tests

- still needs to fix 2 tests

* replace by "or"

- could not fix it from GitHub GUI

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* refactor a bit testing code + add readme

* make style

* fix import issue

* Update src/transformers/modeling_utils.py

Co-authored-by: Michael Benayoun <mickbenayoun@gmail.com>

* add few comments

* add more doctring + make style

* more docstring

* raise error when loaded in 8bit

* make style

* add warning if loaded on CPU

* add small sanity check

* fix small comment

* add bitsandbytes on dockerfile

* Improve documentation

- improve documentation from comments

* add few comments

* slow tests pass on the VM but not on the CI VM

* Fix merge conflict

* make style

* another test should pass on a multi gpu setup

* fix bad import in testing file

* Fix slow tests

- remove dummy batches
- no more CUDA illegal memory errors

* odify dockerfile

* Update docs/source/en/main_classes/model.mdx

* Update Dockerfile

* Update model.mdx

* Update Dockerfile

* Apply suggestions from code review

* few modifications

- lm head can stay on disk/cpu
- change model name so that test pass

* change test value

- change test value to the correct output
- torch bmm changed to baddmm in bloom modeling when merging

* modify installation guidelines

* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Apply suggestions from code review

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* replace `n`by `name`

* merge `load_in_8bit` and `low_cpu_mem_usage`

* first try - keep the lm head in full precision

* better check

- check the attribute `base_model_prefix` instead of computing the number of parameters

* added more tests

* Update src/transformers/utils/bitsandbytes.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Merge branch 'integration-8bit' of https://github.com/younesbelkada/transformers into integration-8bit

* improve documentation

- fix typos for installation
- change title in the documentation

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Michael Benayoun <mickbenayoun@gmail.com>

* TF: XLA-trainable DeBERTa v2 (#18546)

* fix deberta issues

* add different code paths for gpu and tpu

* shorter gpu take along axis

* Stable Dropout without tf cond

* variable must be float

* Preserve hub-related kwargs in AutoModel.from_pretrained (#18545)

* Preserve hub-related kwargs in AutoModel.from_pretrained

* Fix tests

* Remove debug statement

* TF Examples Rewrite (#18451)

* Finished QA example

* Dodge a merge conflict

* Update text classification and LM examples

* Update NER example

* New Keras metrics WIP, fix NER example

* Update NER example

* Update MC, summarization and translation examples

* Add XLA warnings when shapes are variable

* Make sure batch_size is consistently scaled by num_replicas

* Add PushToHubCallback to all models

* Add docs links for KerasMetricCallback

* Add docs links for prepare_tf_dataset and jit_compile

* Correct inferred model names

* Don't assume the dataset has 'lang'

* Don't assume the dataset has 'lang'

* Write metrics in text classification

* Add 'framework' to TrainingArguments and TFTrainingArguments

* Export metrics in all examples and add tests

* Fix training args for Flax

* Update command line args for translation test

* make fixup

* Fix accidentally running other tests in fp16

* Remove do_train/do_eval from run_clm.py

* Remove do_train/do_eval from run_mlm.py

* Add tensorflow tests to circleci

* Fix circleci

* Update examples/tensorflow/language-modeling/run_mlm.py

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* Update examples/tensorflow/test_tensorflow_examples.py

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* Update examples/tensorflow/translation/run_translation.py

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* Update examples/tensorflow/token-classification/run_ner.py

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* Fix save path for tests

* Fix some model card kwargs

* Explain the magical -1000

* Actually enable tests this time

* Skip text classification PR until we fix shape inference

* make fixup

Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>

* Use commit hash to look in cache instead of calling head (#18534)

* Use commit hash to look in cache instead of calling head

* Add tests

* Add attr for local configs too

* Stupid typos

* Fix tests

* Update src/transformers/utils/hub.py

Co-authored-by: Julien Chaumond <julien@huggingface.co>

* Address Julien's comments

Co-authored-by: Julien Chaumond <julien@huggingface.co>

* `pipeline` support for `device="mps"` (or any other string) (#18494)

* `pipeline` support for `device="mps"` (or any other string)

* Simplify `if` nesting

* Update src/transformers/pipelines/base.py

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Fix? @sgugger

* passing `attr=None` is not the same as not passing `attr` 🤯

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update philosophy to include other preprocessing classes (#18550)

* 📝 update philosophy to include other preprocessing classes

* 🖍 apply feedbacks

* Properly move cache when it is not in default path (#18563)

* Adds CLIP to models exportable with ONNX (#18515)

* onnx config for clip

* default opset as 14

* changes from the original repo

* input values order fix

* outputs fix

* remove unused import

* ran make fix-copies

* black format

* review comments: forward ref, import fix, model change revert, .to cleanup

* make style

* formatting fixes

* revert groupvit

* comment for cast to int32

* comment fix

* make .T as .t() for onnx conversion

* ran make fix-copies

* remove unneeded comment

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* fix copies

* remove comment

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* raise atol for MT5OnnxConfig (#18560)

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>

* fix string (#18568)

* Segformer TF: fix output size in documentation (#18572)

* Segformer TF: fix output size in doc

* Segformer pytorch: fix output size in doc

Co-authored-by: Maxime Gardoni <maxime.gardoni@ecorobotix.com>

* Fix resizing bug in OWL-ViT (#18573)

* Fixes resizing bug in OWL-ViT
* Defaults to square resize if size is set to an int
* Sets do_center_crop default value to False

* Fix LayoutLMv3 documentation (#17932)

* fix typos

* fix sequence_length docs of LayoutLMv3Model

* delete trailing white spaces

* fix layoutlmv3 docs more

* apply make fixup & quality

* change to two versions of input docstring

* apply make fixup & quality

* Skip broken tests

* Change BartLearnedPositionalEmbedding's forward method signature to support Opacus training (#18486)

* changing BartLearnedPositionalEmbedding forward signature and references to it

* removing debugging dead code (thanks style checker)

* blackened modeling_bart file

* removing copy inconsistencies via make fix-copies

* changing references to copied signatures in Bart variants

* make fix-copies once more

* using expand over repeat (thanks @michaelbenayoun)

* expand instead of repeat for all model copies

Co-authored-by: Daniel Jones <jonesdaniel@microsoft.com>

* german docs translation (#18544)

* Create _config.py

* Create _toctree.yml

* Create index.mdx

not sure about "du / ihr" oder "sie"

* Create quicktour.mdx

* Update _toctree.yml

* Update build_documentation.yml

* Update build_pr_documentation.yml

* fix build

* Update index.mdx

* Update quicktour.mdx

* Create installation.mdx

* Update _toctree.yml

* Deberta V2: Fix critical trace warnings to allow ONNX export (#18272)

* Fix critical trace warnings to allow ONNX export

* Force input to `sqrt` to be float type

* Cleanup code

* Remove unused import statement

* Update model sew

* Small refactor

Co-authored-by: Michael Benayoun <mickbenayoun@gmail.com>

* Use broadcasting instead of repeat

* Implement suggestion

Co-authored-by: Michael Benayoun <mickbenayoun@gmail.com>

* Match deberta v2 changes in sew_d

* Improve code quality

* Update code quality

* Consistency of small refactor

* Match changes in sew_d

Co-authored-by: Michael Benayoun <mickbenayoun@gmail.com>

* [FX] _generate_dummy_input supports audio-classification models for labels (#18580)

* Support audio classification architectures for labels generation, as well as provides a flag to print warnings or not

* Use ENV_VARS_TRUE_VALUES

* Fix docstrings with last version of hf-doc-builder styler (#18581)

* Fix docstrings with last version of hf-doc-builder styler

* Remove empty Parameter block

* Bump nbconvert from 6.0.1 to 6.3.0 in /examples/research_projects/lxmert (#18565)

Bumps [nbconvert](https://github.com/jupyter/nbconvert) from 6.0.1 to 6.3.0.
- [Release notes](https://github.com/jupyter/nbconvert/releases)
- [Commits](jupyter/nbconvert@6.0.1...6.3.0)

---
updated-dependencies:
- dependency-name: nbconvert
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>

* Bump nbconvert in /examples/research_projects/visual_bert (#18566)

Bumps [nbconvert](https://github.com/jupyter/nbconvert) from 6.0.1 to 6.3.0.
- [Release notes](https://github.com/jupyter/nbconvert/releases)
- [Commits](jupyter/nbconvert@6.0.1...6.3.0)

---
updated-dependencies:
- dependency-name: nbconvert
  dependency-type: direct:production
...

Signed-off-by: dependabot[bot] <support@github.com>

Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>

* fix owlvit tests, update docstring examples (#18586)

* Return the permuted hidden states if return_dict=True (#18578)

* Load sharded pt to flax (#18419)

* initial commit

* add small test

* add cross pt tf flag to test

* fix quality

* style

* update test with new repo

* fix failing test

* update

* fix wrong param ordering

* style

* update based on review

* update related to recent new caching mechanism

* quality

* Update based on review

Co-authored-by: sgugger <sylvain.gugger@gmail.com>

* quality and style

* Update src/transformers/modeling_flax_utils.py
Co-authored-by: sgugger <sylvain.gugger@gmail.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Add type hints for ViLT models (#18577)

* Add type hints for Vilt models

* Add missing return type for TokenClassification class

* update doc for perf_train_cpu_many, add intel mpi introduction (#18576)

* update doc for perf_train_cpu_many, add mpi introduction

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

* Update docs/source/en/perf_train_cpu_many.mdx

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Update docs/source/en/perf_train_cpu_many.mdx

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* typos (#18594)

* FSDP bug fix for `load_state_dict` (#18596)

* Add `TFAutoModelForSemanticSegmentation` to the main `__init__.py` (#18600)

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>

* Generate: validate `model_kwargs` (and catch typos in generate arguments) (#18261)

* validate generate model_kwargs

* generate tests -- not all models have an attn mask

* Supporting seq2seq models for `bitsandbytes` integration (#18579)

* Supporting seq2seq models for `bitsandbytes` integration

- `bitsandbytes` integration supports now seq2seq models
- check if a model has tied weights as an additional check

* small modification

- tie the weights before looking at tied weights!

* Add Donut (#18488)

* First draft

* Improve script

* Update script

* Make conversion work

* Add final_layer_norm attribute to Swin's config

* Add DonutProcessor

* Convert more models

* Improve feature extractor and convert base models

* Fix bug

* Improve integration tests

* Improve integration tests and add model to README

* Add doc test

* Add feature extractor to docs

* Fix integration tests

* Remove register_buffer

* Fix toctree and add missing attribute

* Add DonutSwin

* Make conversion script work

* Improve conversion script

* Address comment

* Fix bug

* Fix another bug

* Remove deprecated method from docs

* Make Swin and Swinv2 untouched

* Fix code examples

* Fix processor

* Update model_type to donut-swin

* Add feature extractor tests, add token2json method, improve feature extractor

* Fix failing tests, remove integration test

* Add do_thumbnail for consistency

* Improve code examples

* Add code example for document parsing

* Add DonutSwin to MODEL_NAMES_MAPPING

* Add model to appropriate place in toctree

* Update namespace to appropriate organization

Co-authored-by: Niels Rogge <nielsrogge@Nielss-MacBook-Pro.local>

* Fix URLs (#18604)

Co-authored-by: Niels Rogge <nielsrogge@Nielss-MacBook-Pro.local>

* Update BLOOM parameter counts (#18531)

* Update BLOOM parameter counts

* Update BLOOM parameter counts

* [doc] fix anchors (#18591)

the manual anchors end up being duplicated with automatically added anchors and no longer work.

* [fsmt] deal with -100 indices in decoder ids (#18592)

* [fsmt] deal with -100 indices in decoder ids

Fixes: #17945

decoder ids get the default index -100, which breaks the model - like t5 and many other models add a fix to replace -100 with the correct pad index. 

For some reason this use case hasn't been used with this model until recently - so this issue was there since the beginning it seems.

Any suggestions to how to add a simple test here? or perhaps we have something similar already? user's script is quite massive.

* style

* small change (#18584)

* Flax Remat for LongT5 (#17994)

* [Flax] Add remat (gradient checkpointing)

* fix variable naming in test

* flip: checkpoint using a method

* fix naming

* fix class naming

* apply PVP's suggestions from code review

* add gradient_checkpointing to examples

* Add gradient_checkpointing to run_mlm_flax

* Add remat to longt5

* Add gradient checkpointing test longt5

* Fix args errors

* Fix remaining tests

* Make fixup & quality fixes

* replace kwargs

* remove unecessary kwargs

* Make fixup changes

* revert long_t5_flax changes

* Remove return_dict and copy to LongT5

* Remove test_gradient_checkpointing

Co-authored-by: sanchit-gandhi <sanchit@huggingface.co>

* mac m1 `mps` integration (#18598)

* mac m1 `mps` integration

* Update docs/source/en/main_classes/trainer.mdx

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* addressing comments

* Apply suggestions from code review

Co-authored-by: Dan Saattrup Nielsen <47701536+saattrupdan@users.noreply.github.com>

* resolve comment

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Dan Saattrup Nielsen <47701536+saattrupdan@users.noreply.github.com>

* Change scheduled CIs to use torch 1.12.1 (#18644)

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>

* Add checks for some workflow jobs (#18583)

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>

* TF: Fix generation repetition penalty with XLA (#18648)

* Update longt5.mdx (#18634)

* Update run_translation_no_trainer.py (#18637)

* Update run_translation_no_trainer.py

found an error in selecting `no_decay` parameters and some small modifications when the user continues to train from a checkpoint

* fixs `no_decay` and `resume_step` issue

1. change `no_decay` list
2. if use continue to train their model from provided checkpoint, the `resume_step` will not be initialized properly if `args.gradient_accumulation_steps != 1`

* [bnb] Minor modifications (#18631)

* bnb minor modifications

- refactor documentation
- add troubleshooting README
- add PyPi library on DockerFile

* Apply suggestions from code review

Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>

* Apply suggestions from code review

* Apply suggestions from code review

* Apply suggestions from code review

* put in one block

- put bash instructions in one block

* update readme

- refactor a bit hardware requirements

* change text a bit

* Apply suggestions from code review

Co-authored-by: Yih-Dar <2521628+ydshieh@users.noreply.github.com>

* apply suggestions

Co-authored-by: Yih-Dar <2521628+ydshieh@users.noreply.github.com>

* add link to paper

* Apply suggestions from code review

Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>

* Update tests/mixed_int8/README.md

* Apply suggestions from code review

* refactor a bit

* add instructions Turing & Amperer

Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>

* add A6000

* clarify a bit

* remove small part

* Update tests/mixed_int8/README.md

Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
Co-authored-by: Yih-Dar <2521628+ydshieh@users.noreply.github.com>

* Examples: add Bloom support for token classification (#18632)

* examples: add Bloom support for token classification (FLAX, PyTorch and TensorFlow)

* examples: remove support for Bloom in token classication (FLAX and TensorFlow currently have no support for it)

* Fix Yolos ONNX export test (#18606)

Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>

* Fixup

* Fix up

* Move PIL default arguments inside function for safe imports

* Add image utils to toctree

* Update `rescale` method to reflect changes in #18677

* Update docs/source/en/internal/image_processing_utils.mdx

Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>

* Address Niels PR comments

* Apply suggestions from code review - remove defaults to None

Co-authored-by: Sylvain Gugger <Sylvain.gugger@gmail.com>

Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>

* Fix docstrings and revert to PIL.Image.XXX resampling

Use PIL.Image.XXX resampling values instead of PIL.Image.Resampling.XXX enum as it's only in the recent version >= 9.10 and version is not yet pinned and older version support deprecated

* Some more docstrings and PIL.Image tidy up

* Reorganise arguments so flags by modifiers

* Few last docstring fixes

Signed-off-by: Seunghwan Hong <seunghwan@scatterlab.co.kr>
Signed-off-by: dependabot[bot] <support@github.com>
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Co-authored-by: Amy Roberts <amyeroberts@users.noreply.github.com>
Co-authored-by: Sylvain Gugger <35901082+sgugger@users.noreply.github.com>
Co-authored-by: Alara Dirik <8944735+alaradirik@users.noreply.github.com>
Co-authored-by: NielsRogge <48327001+NielsRogge@users.noreply.github.com>
Co-authored-by: Seunghwan Hong <harrydrippin@gmail.com>
Co-authored-by: Yih-Dar <2521628+ydshieh@users.noreply.github.com>
Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
Co-authored-by: Sylvain Gugger <Sylvain.gugger@gmail.com>
Co-authored-by: Julien Chaumond <julien@huggingface.co>
Co-authored-by: regisss <15324346+regisss@users.noreply.github.com>
Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
Co-authored-by: Sourab Mangrulkar <13534540+pacman100@users.noreply.github.com>
Co-authored-by: Ankur Goyal <ankrgyl@gmail.com>
Co-authored-by: Ankur Goyal <ankur@impira.com>
Co-authored-by: Steven Liu <59462357+stevhliu@users.noreply.github.com>
Co-authored-by: Niels Rogge <nielsrogge@Nielss-MacBook-Pro.local>
Co-authored-by: Mishig Davaadorj <dmishig@gmail.com>
Co-authored-by: Rasmus Arpe Fogh Jensen <Rasmus.arpe@gmail.com>
Co-authored-by: Ian Castillo <7807897+donelianc@users.noreply.github.com>
Co-authored-by: AguilaCudicio <aguila.cudicio@gmail.com>
Co-authored-by: Omar U. Espejel <espejelomar@gmail.com>
Co-authored-by: Younes Belkada <49240599+younesbelkada@users.noreply.github.com>
Co-authored-by: Thomas Wang <24695242+thomasw21@users.noreply.github.com>
Co-authored-by: Niklas Hansson <niklas.sven.hansson@gmail.com>
Co-authored-by: Thomas Chaigneau <t.chaigneau.tc@gmail.com>
Co-authored-by: YouJiacheng <1503679330@qq.com>
Co-authored-by: Michael Benayoun <mickbenayoun@gmail.com>
Co-authored-by: Joao Gante <joaofranciscocardosogante@gmail.com>
Co-authored-by: Matt <Rocketknight1@users.noreply.github.com>
Co-authored-by: Dhruv Karan <k4r4n.dhruv@gmail.com>
Co-authored-by: Michael Wyatt <mrwyattii@gmail.com>
Co-authored-by: Maxime G <joihn@users.noreply.github.com>
Co-authored-by: Maxime Gardoni <maxime.gardoni@ecorobotix.com>
Co-authored-by: Wonseok Lee (Jack) <rollerkid02@snu.ac.kr>
Co-authored-by: Dan Jones <dan.j.jones2@gmail.com>
Co-authored-by: Daniel Jones <jonesdaniel@microsoft.com>
Co-authored-by: flozi00 <flozi00.fz@gmail.com>
Co-authored-by: iiLaurens <iiLaurens@users.noreply.github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com>
Co-authored-by: Wang, Yi <yi.a.wang@intel.com>
Co-authored-by: Stas Bekman <stas00@users.noreply.github.com>
Co-authored-by: Niklas Muennighoff <n.muennighoff@gmail.com>
Co-authored-by: Karim Foda <35491698+KMFODA@users.noreply.github.com>
Co-authored-by: sanchit-gandhi <sanchit@huggingface.co>
Co-authored-by: Dan Saattrup Nielsen <47701536+saattrupdan@users.noreply.github.com>
Co-authored-by: zhoutang776 <47708118+zhoutang776@users.noreply.github.com>
Co-authored-by: Stefan Schweter <stefan@schweter.it>
Co-authored-by: lewtun <lewis.c.tunstall@gmail.com>
  • Loading branch information
Show file tree
Hide file tree
Showing 13 changed files with 893 additions and 129 deletions.
2 changes: 2 additions & 0 deletions docs/source/en/_toctree.yml
Original file line number Diff line number Diff line change
Expand Up @@ -521,6 +521,8 @@
title: Utilities for Trainer
- local: internal/generation_utils
title: Utilities for Generation
- local: internal/image_processing_utils
title: Utilities for Image Processors
- local: internal/file_utils
title: General Utilities
title: Internal Helpers
Expand Down
30 changes: 30 additions & 0 deletions docs/source/en/internal/image_processing_utils.mdx
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
<!--Copyright 2022 The HuggingFace Team. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
-->

# Utilities for Image Processors

This page lists all the utility functions used by the image processors, mainly the functional
transformations used to process the images.

Most of those are only useful if you are studying the code of the image processors in the library.

## Image Transformations

[[autodoc]] image_transforms.rescale

[[autodoc]] image_transforms.resize

[[autodoc]] image_transforms.to_pil_image

## ImageProcessorMixin

[[autodoc]] image_processing_utils.ImageProcessorMixin
4 changes: 4 additions & 0 deletions src/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -680,6 +680,8 @@
name for name in dir(dummy_vision_objects) if not name.startswith("_")
]
else:
_import_structure["image_processing_utils"] = ["ImageProcessorMixin"]
_import_structure["image_transforms"] = ["rescale", "resize", "to_pil_image"]
_import_structure["image_utils"] = ["ImageFeatureExtractionMixin"]
_import_structure["models.beit"].append("BeitFeatureExtractor")
_import_structure["models.clip"].append("CLIPFeatureExtractor")
Expand Down Expand Up @@ -3648,6 +3650,8 @@
except OptionalDependencyNotAvailable:
from .utils.dummy_vision_objects import *
else:
from .image_processing_utils import ImageProcessorMixin
from .image_transforms import rescale, resize, to_pil_image
from .image_utils import ImageFeatureExtractionMixin
from .models.beit import BeitFeatureExtractor
from .models.clip import CLIPFeatureExtractor
Expand Down
54 changes: 54 additions & 0 deletions src/transformers/image_processing_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .feature_extraction_utils import BatchFeature as BaseBatchFeature
from .feature_extraction_utils import FeatureExtractionMixin
from .utils import logging


logger = logging.get_logger(__name__)


# TODO: Move BatchFeature to be imported by both feature_extraction_utils and image_processing_utils
# We override the class string here, but logic is the same.
class BatchFeature(BaseBatchFeature):
r"""
Holds the output of the image processor specific `__call__` methods.
This class is derived from a python dictionary and can be used as a dictionary.
Args:
data (`dict`):
Dictionary of lists/arrays/tensors returned by the __call__ method ('pixel_values', etc.).
tensor_type (`Union[None, str, TensorType]`, *optional*):
You can give a tensor_type here to convert the lists of integers in PyTorch/TensorFlow/Numpy Tensors at
initialization.
"""


# We use aliasing whilst we phase out the old API. Once feature extractors for vision models
# are deprecated, ImageProcessor mixin will be implemented. Any shared logic will be abstracted out.
ImageProcessorMixin = FeatureExtractionMixin


class BaseImageProcessor(ImageProcessorMixin):
def __init__(self, **kwargs):
super().__init__(**kwargs)

def __call__(self, images, **kwargs) -> BatchFeature:
return self.preprocess(images, **kwargs)

def preprocess(self, images, **kwargs) -> BatchFeature:
raise NotImplementedError("Each image processor must implement its own preprocess method")
259 changes: 259 additions & 0 deletions src/transformers/image_transforms.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,259 @@
# coding=utf-8
# Copyright 2022 The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import TYPE_CHECKING, List, Optional, Tuple, Union

import numpy as np

from transformers.utils.import_utils import is_flax_available, is_tf_available, is_torch_available, is_vision_available


if is_vision_available():
import PIL

from .image_utils import (
ChannelDimension,
get_image_size,
infer_channel_dimension_format,
is_jax_tensor,
is_tf_tensor,
is_torch_tensor,
)


if TYPE_CHECKING:
if is_torch_available():
import torch
if is_tf_available():
import tensorflow as tf
if is_flax_available():
import jax.numpy as jnp


def to_channel_dimension_format(image: np.ndarray, channel_dim: Union[ChannelDimension, str]) -> np.ndarray:
"""
Converts `image` to the channel dimension format specified by `channel_dim`.
Args:
image (`numpy.ndarray`):
The image to have its channel dimension set.
channel_dim (`ChannelDimension`):
The channel dimension format to use.
Returns:
`np.ndarray`: The image with the channel dimension set to `channel_dim`.
"""
if not isinstance(image, np.ndarray):
raise ValueError(f"Input image must be of type np.ndarray, got {type(image)}")

current_channel_dim = infer_channel_dimension_format(image)
target_channel_dim = ChannelDimension(channel_dim)
if current_channel_dim == target_channel_dim:
return image

if target_channel_dim == ChannelDimension.FIRST:
image = image.transpose((2, 0, 1))
elif target_channel_dim == ChannelDimension.LAST:
image = image.transpose((1, 2, 0))
else:
raise ValueError("Unsupported channel dimension format: {}".format(channel_dim))

return image


def rescale(
image: np.ndarray, scale: float, data_format: Optional[ChannelDimension] = None, dtype=np.float32
) -> np.ndarray:
"""
Rescales `image` by `scale`.
Args:
image (`np.ndarray`):
The image to rescale.
scale (`float`):
The scale to use for rescaling the image.
data_format (`ChannelDimension`, *optional*):
The channel dimension format of the image. If not provided, it will be the same as the input image.
dtype (`np.dtype`, *optional*, defaults to `np.float32`):
The dtype of the output image. Defaults to `np.float32`. Used for backwards compatibility with feature
extractors.
Returns:
`np.ndarray`: The rescaled image.
"""
if not isinstance(image, np.ndarray):
raise ValueError(f"Input image must be of type np.ndarray, got {type(image)}")

rescaled_image = image * scale
if data_format is not None:
rescaled_image = to_channel_dimension_format(rescaled_image, data_format)
rescaled_image = rescaled_image.astype(dtype)
return rescaled_image


def to_pil_image(
image: Union[np.ndarray, PIL.Image.Image, "torch.Tensor", "tf.Tensor", "jnp.Tensor"],
do_rescale: Optional[bool] = None,
) -> PIL.Image.Image:
"""
Converts `image` to a PIL Image. Optionally rescales it and puts the channel dimension back as the last axis if
needed.
Args:
image (`PIL.Image.Image` or `numpy.ndarray` or `torch.Tensor` or `tf.Tensor`):
The image to convert to the `PIL.Image` format.
do_rescale (`bool`, *optional*):
Whether or not to apply the scaling factor (to make pixel values integers between 0 and 255). Will default
to `True` if the image type is a floating type, `False` otherwise.
Returns:
`PIL.Image.Image`: The converted image.
"""
if isinstance(image, PIL.Image.Image):
return image

# Convert all tensors to numpy arrays before converting to PIL image
if is_torch_tensor(image) or is_tf_tensor(image):
image = image.numpy()
elif is_jax_tensor(image):
image = np.array(image)
elif not isinstance(image, np.ndarray):
raise ValueError("Input image type not supported: {}".format(type(image)))

# If the channel as been moved to first dim, we put it back at the end.
image = to_channel_dimension_format(image, ChannelDimension.LAST)

# PIL.Image can only store uint8 values, so we rescale the image to be between 0 and 255 if needed.
do_rescale = isinstance(image.flat[0], float) if do_rescale is None else do_rescale
if do_rescale:
image = rescale(image, 255)
image = image.astype(np.uint8)
return PIL.Image.fromarray(image)


def get_resize_output_image_size(
input_image: np.ndarray,
size: Union[int, Tuple[int, int], List[int], Tuple[int]],
default_to_square: bool = True,
max_size: Optional[int] = None,
) -> tuple:
"""
Find the target (height, width) dimension of the output image after resizing given the input image and the desired
size.
Args:
input_image (`np.ndarray`):
The image to resize.
size (`int` or `Tuple[int, int]` or List[int] or Tuple[int]):
The size to use for resizing the image. If `size` is a sequence like (h, w), output size will be matched to
this.
If `size` is an int and `default_to_square` is `True`, then image will be resized to (size, size). If
`size` is an int and `default_to_square` is `False`, then smaller edge of the image will be matched to this
number. i.e, if height > width, then image will be rescaled to (size * height / width, size).
default_to_square (`bool`, *optional*, defaults to `True`):
How to convert `size` when it is a single int. If set to `True`, the `size` will be converted to a square
(`size`,`size`). If set to `False`, will replicate
[`torchvision.transforms.Resize`](https://pytorch.org/vision/stable/transforms.html#torchvision.transforms.Resize)
with support for resizing only the smallest edge and providing an optional `max_size`.
max_size (`int`, *optional*):
The maximum allowed for the longer edge of the resized image: if the longer edge of the image is greater
than `max_size` after being resized according to `size`, then the image is resized again so that the longer
edge is equal to `max_size`. As a result, `size` might be overruled, i.e the smaller edge may be shorter
than `size`. Only used if `default_to_square` is `False`.
Returns:
`tuple`: The target (height, width) dimension of the output image after resizing.
"""
if isinstance(size, (tuple, list)):
if len(size) == 2:
return tuple(size)
elif len(size) == 1:
# Perform same logic as if size was an int
size = size[0]
else:
raise ValueError("size must have 1 or 2 elements if it is a list or tuple")

if default_to_square:
return (size, size)

height, width = get_image_size(input_image)
short, long = (width, height) if width <= height else (height, width)
requested_new_short = size

if short == requested_new_short:
return (height, width)

new_short, new_long = requested_new_short, int(requested_new_short * long / short)

if max_size is not None:
if max_size <= requested_new_short:
raise ValueError(
f"max_size = {max_size} must be strictly greater than the requested "
f"size for the smaller edge size = {size}"
)
if new_long > max_size:
new_short, new_long = int(max_size * new_short / new_long), max_size

return (new_long, new_short) if width <= height else (new_short, new_long)


def resize(
image,
size: Tuple[int, int],
resample=PIL.Image.BILINEAR,
data_format: Optional[ChannelDimension] = None,
return_numpy: bool = True,
) -> np.ndarray:
"""
Resizes `image` to (h, w) specified by `size` using the PIL library.
Args:
image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`):
The image to resize.
size (`Tuple[int, int]`):
The size to use for resizing the image.
resample (`int`, *optional*, defaults to `PIL.Image.BILINEAR`):
The filter to user for resampling.
data_format (`ChannelDimension`, *optional*):
The channel dimension format of the output image. If `None`, will use the inferred format from the input.
return_numpy (`bool`, *optional*, defaults to `True`):
Whether or not to return the resized image as a numpy array. If False a `PIL.Image.Image` object is
returned.
Returns:
`np.ndarray`: The resized image.
"""
if not len(size) == 2:
raise ValueError("size must have 2 elements")

# For all transformations, we want to keep the same data format as the input image unless otherwise specified.
# The resized image from PIL will always have channels last, so find the input format first.
data_format = infer_channel_dimension_format(image) if data_format is None else data_format

# To maintain backwards compatibility with the resizing done in previous image feature extractors, we use
# the pillow library to resize the image and then convert back to numpy
if not isinstance(image, PIL.Image.Image):
# PIL expects image to have channels last
image = to_channel_dimension_format(image, ChannelDimension.LAST)
image = to_pil_image(image)
height, width = size
# PIL images are in the format (width, height)
resized_image = image.resize((width, height), resample=resample)

if return_numpy:
resized_image = np.array(resized_image)
resized_image = to_channel_dimension_format(resized_image, data_format)
return resized_image
Loading

0 comments on commit 1973b77

Please sign in to comment.