forked from facebookresearch/deit

Commit 1d38fa4 (0 parents): 17 changed files with 1,672 additions and 0 deletions.

New file: .circleci/config.yml
version: 2.1

jobs:
  python_lint:
    docker:
      - image: circleci/python:3.7
    steps:
      - checkout
      - run:
          command: |
            pip install --user --progress-bar off flake8 typing
            flake8 .
  test:
    docker:
      - image: circleci/python:3.7
    steps:
      - checkout
      - run:
          command: |
            pip install --user --progress-bar off pytest
            pip install --user --progress-bar off torch torchvision
            pip install --user --progress-bar off timm==0.3.2
            pytest .

workflows:
  build:
    jobs:
      - python_lint
      - test
New file: .github/CODE_OF_CONDUCT.md
# Code of Conduct

Facebook has adopted a Code of Conduct that we expect project participants to adhere to.
Please read the [full text](https://code.fb.com/codeofconduct/)
so that you can understand what actions will and will not be tolerated.
New file: .github/CONTRIBUTING.md
# Contributing to DeiT
We want to make contributing to this project as easy and transparent as
possible.

## Our Development Process
Minor changes and improvements will be released on an ongoing basis. Larger changes (e.g., changesets implementing a new paper) will be released on a more periodic basis.

## Pull Requests
We actively welcome your pull requests.

1. Fork the repo and create your branch from `main`.
2. If you've added code that should be tested, add tests.
3. If you've changed APIs, update the documentation.
4. Ensure the test suite passes.
5. Make sure your code lints.
6. If you haven't already, complete the Contributor License Agreement ("CLA").

## Contributor License Agreement ("CLA")
In order to accept your pull request, we need you to submit a CLA. You only need
to do this once to work on any of Facebook's open source projects.

Complete your CLA here: <https://code.facebook.com/cla>

## Issues
We use GitHub issues to track public bugs. Please ensure your description is
clear and has sufficient instructions to be able to reproduce the issue.

Facebook has a [bounty program](https://www.facebook.com/whitehat/) for the safe
disclosure of security bugs. In those cases, please go through the process
outlined on that page and do not file a public issue.

## Coding Style
* 4 spaces for indentation rather than tabs
* 80 character line length
* PEP8 formatting following [Black](https://black.readthedocs.io/en/stable/)

## License
By contributing to DeiT, you agree that your contributions will be licensed
under the LICENSE file in the root directory of this source tree.
New file (binary, not rendered)

New file: .gitignore
*.swp
**/__pycache__/**
imnet_resnet50_scratch/timm_temp/
.dumbo.json
New file (large diff, not rendered)

New file: README.md
# DeiT: Data-efficient Image Transformers

This repository contains PyTorch evaluation code, training code and pretrained models for DeiT (Data-Efficient Image Transformers).

The models obtain competitive tradeoffs between speed and accuracy:

![DeiT](.github/deit.png)

For details see [Training data-efficient image transformers & distillation through attention](https://arxiv.org/abs/2012.12877) by Hugo Touvron, Matthieu Cord, Matthijs Douze, Francisco Massa, Alexandre Sablayrolles and Hervé Jégou.

```
@article{touvron2020deit,
  title={Training data-efficient image transformers \& distillation through attention},
  author={Hugo Touvron and Matthieu Cord and Matthijs Douze and Francisco Massa and Alexandre Sablayrolles and Herv\'e J\'egou},
  journal={arXiv preprint arXiv:2012.12877},
  year={2020}
}
```

# Model Zoo

We provide baseline DeiT models pretrained on ImageNet 2012.

| name | acc@1 | acc@5 | url |
| --- | --- | --- | --- |
| DeiT-tiny | 72.2 | 91.1 | [model](https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth) |
| DeiT-small | 79.9 | 95.0 | [model](https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth) |
| DeiT-base | 81.8 | 95.6 | [model](https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth) |

The models are also available via torch hub.
Before using them, make sure you have [`timm==0.3.2`](https://github.com/rwightman/pytorch-image-models) installed. To load DeiT-base with pretrained weights on ImageNet, simply do:

```python
import torch
# check you have the right version of timm
import timm
assert timm.__version__ == "0.3.2"

# now load it with torch hub
model = torch.hub.load('facebookresearch/deit', 'deit_base_patch16_224', pretrained=True)
```
# Usage

First, clone the repository locally:
```
git clone https://github.com/facebookresearch/deit.git
```
Then, install PyTorch 1.7.0+, torchvision 0.8.1+ and [pytorch-image-models 0.3.2](https://github.com/rwightman/pytorch-image-models):

```
conda install -c pytorch pytorch torchvision
pip install timm==0.3.2
```

## Data preparation

Download and extract the ImageNet train and val images from http://image-net.org/.
The directory structure is the standard layout for torchvision's [`datasets.ImageFolder`](https://pytorch.org/docs/stable/torchvision/datasets.html#imagefolder): the training and validation data are expected in the `train/` and `val/` folders respectively:

```
/path/to/imagenet/
  train/
    class1/
      img1.jpeg
    class2/
      img2.jpeg
  val/
    class1/
      img3.jpeg
    class2/
      img4.jpeg
```
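One detail of this layout worth knowing: `ImageFolder` assigns integer labels by sorting the class directory names lexicographically, so the `train/` and `val/` trees must contain the same class names. A small stdlib-only sketch of that convention (the directory names are made up):

```python
import os
import tempfile

# Build a toy val/ tree and reproduce ImageFolder's class-to-index rule:
# classes are the subdirectory names, sorted, numbered from 0.
root = tempfile.mkdtemp()
for cls in ["class2", "class1", "class10"]:
    os.makedirs(os.path.join(root, cls))
    open(os.path.join(root, cls, "img.jpeg"), "w").close()

classes = sorted(d for d in os.listdir(root) if os.path.isdir(os.path.join(root, d)))
class_to_idx = {cls: i for i, cls in enumerate(classes)}
print(class_to_idx)  # {'class1': 0, 'class10': 1, 'class2': 2}
```

Note that the sort is lexicographic, not numeric: `class10` comes before `class2`, which is why consistent naming between splits matters.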
## Evaluation
To evaluate a pre-trained DeiT-base on the ImageNet val set with a single GPU, run:
```
python main.py --eval --resume https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth --data-path /path/to/imagenet
```
This should give
```
* Acc@1 81.846 Acc@5 95.594 loss 0.820
```

For DeiT-small, run:
```
python main.py --eval --resume https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth --model deit_small_patch16_224 --data-path /path/to/imagenet
```
giving
```
* Acc@1 79.854 Acc@5 94.968 loss 0.881
```

And for DeiT-tiny:
```
python main.py --eval --resume https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth --model deit_tiny_patch16_224 --data-path /path/to/imagenet
```
which should give
```
* Acc@1 72.202 Acc@5 91.124 loss 1.219
```
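The Acc@1 and Acc@5 figures above are standard top-k accuracies: a sample counts as correct at k if the true class is among the k highest-scoring logits. A minimal sketch of the computation in plain PyTorch (the function name and the toy tensors are ours, not from main.py):

```python
import torch

def topk_accuracy(logits, targets, ks=(1, 5)):
    # A prediction is correct at k if the true class is among the k largest logits.
    maxk = max(ks)
    _, pred = logits.topk(maxk, dim=1)         # [N, maxk] predicted class indices
    correct = pred.eq(targets.view(-1, 1))     # [N, maxk] boolean matches
    return [100.0 * correct[:, :k].any(dim=1).float().mean().item() for k in ks]

logits = torch.tensor([[0.1, 0.9, 0.0, 0.0, 0.0, 0.0],
                       [0.8, 0.1, 0.0, 0.0, 0.0, 0.0],
                       [0.0, 0.0, 0.2, 0.3, 0.4, 0.1]])
targets = torch.tensor([1, 1, 5])
acc1, acc5 = topk_accuracy(logits, targets)
print(acc1, acc5)  # ~33.33 and 100.0: only one top-1 hit, all within top-5
```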

## Training
To train DeiT-small and DeiT-tiny on ImageNet on a single node with 4 GPUs for 300 epochs, run:

DeiT-small
```
python -m torch.distributed.launch --nproc_per_node=4 --use_env main.py --model deit_small_patch16_224 --batch-size 256 --data-path /path/to/imagenet
```

DeiT-tiny
```
python -m torch.distributed.launch --nproc_per_node=4 --use_env main.py --model deit_tiny_patch16_224 --batch-size 256 --data-path /path/to/imagenet
```
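Note that `--batch-size` is per GPU, so the commands above train with an effective batch of 4 × 256 = 1024. Many ViT training recipes, DeiT included, scale the base learning rate linearly with the effective batch size; the sketch below illustrates that rule under the assumption of a reference batch of 512 (the function name and constant are ours, not taken verbatim from this repo's main.py):

```python
def linear_scaled_lr(base_lr, batch_size_per_gpu, num_gpus, reference_batch=512):
    # Assumed linear-scaling rule: the learning rate grows proportionally
    # with the total batch size, normalized to a reference batch of 512.
    total_batch = batch_size_per_gpu * num_gpus
    return base_lr * total_batch / reference_batch

# 4 GPUs at --batch-size 256 -> total batch 1024, doubling a 5e-4 base lr.
print(linear_scaled_lr(5e-4, 256, 4))  # 0.001
```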

### Multinode training

Distributed training is available via Slurm and [submitit](https://github.com/facebookincubator/submitit):

```
pip install submitit
```

To train a DeiT-base model on ImageNet on 2 nodes with 8 GPUs each for 300 epochs:

```
python run_with_submitit.py --model deit_base_patch16_224 --data-path /path/to/imagenet
```

# License
This repository is released under the CC-BY-NC 4.0 license as found in the [LICENSE](LICENSE) file.

# Contributing
We actively welcome your pull requests! Please see [CONTRIBUTING.md](.github/CONTRIBUTING.md) and [CODE_OF_CONDUCT.md](.github/CODE_OF_CONDUCT.md) for more info.
New file: datasets.py
# Copyright (c) 2015-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the CC-by-NC license found in the
# LICENSE file in the root directory of this source tree.
#
import math
import os
import json

import torch
from torchvision import datasets, transforms
from torchvision.datasets.folder import ImageFolder, default_loader

from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
from timm.data import create_transform


class INatDataset(ImageFolder):
    def __init__(self, root, train=True, year=2018, transform=None, target_transform=None,
                 category='name', loader=default_loader):
        self.transform = transform
        self.loader = loader
        self.target_transform = target_transform
        self.year = year
        # assert category in ['kingdom','phylum','class','order','supercategory','family','genus','name']
        path_json = os.path.join(root, f'{"train" if train else "val"}{year}.json')
        with open(path_json) as json_file:
            data = json.load(json_file)

        with open(os.path.join(root, 'categories.json')) as json_file:
            data_catg = json.load(json_file)

        path_json_for_targeter = os.path.join(root, f"train{year}.json")

        with open(path_json_for_targeter) as json_file:
            data_for_targeter = json.load(json_file)

        targeter = {}
        indexer = 0
        for elem in data_for_targeter['annotations']:
            king = []
            king.append(data_catg[int(elem['category_id'])][category])
            if king[0] not in targeter.keys():
                targeter[king[0]] = indexer
                indexer += 1
        self.nb_classes = len(targeter)

        self.samples = []
        for elem in data['images']:
            cut = elem['file_name'].split('/')
            target_current = int(cut[2])
            path_current = os.path.join(root, cut[0], cut[2], cut[3])

            categors = data_catg[target_current]
            target_current_true = targeter[categors[category]]
            self.samples.append((path_current, target_current_true))

    # __getitem__ and __len__ inherited from ImageFolder


def build_dataset(is_train, args):
    transform = build_transform(is_train, args)

    if args.data_set == 'CIFAR':
        dataset = datasets.CIFAR100(args.data_path, train=is_train, transform=transform)
        nb_classes = 100
    elif args.data_set == 'IMNET':
        root = os.path.join(args.data_path, 'train' if is_train else 'val')
        dataset = datasets.ImageFolder(root, transform=transform)
        nb_classes = 1000
    elif args.data_set == 'INAT':
        dataset = INatDataset(args.data_path, train=is_train, year=2018,
                              category=args.inat_category, transform=transform)
        nb_classes = dataset.nb_classes
    elif args.data_set == 'INAT19':
        dataset = INatDataset(args.data_path, train=is_train, year=2019,
                              category=args.inat_category, transform=transform)
        nb_classes = dataset.nb_classes

    return dataset, nb_classes


def build_transform(is_train, args):
    resize_im = args.input_size > 32
    if is_train:
        # this should always dispatch to transforms_imagenet_train
        transform = create_transform(
            input_size=args.input_size,
            is_training=True,
            color_jitter=args.color_jitter,
            auto_augment=args.aa,
            interpolation=args.train_interpolation,
            re_prob=args.reprob,
            re_mode=args.remode,
            re_count=args.recount,
        )
        if not resize_im:
            # replace RandomResizedCropAndInterpolation with
            # RandomCrop
            transform.transforms[0] = transforms.RandomCrop(
                args.input_size, padding=4)
        return transform

    t = []
    if resize_im:
        size = int((256 / 224) * args.input_size)
        t.append(
            transforms.Resize(size, interpolation=3),  # to maintain same ratio w.r.t. 224 images
        )
    t.append(transforms.CenterCrop(args.input_size))

    t.append(transforms.ToTensor())
    t.append(transforms.Normalize(IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD))
    return transforms.Compose(t)
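The label remapping inside `INatDataset.__init__` (the `targeter` loop) can be illustrated standalone: distinct category values, in first-seen order over the training annotations, get contiguous integer labels, and categories sharing the same value collapse into one class. The toy data below is ours; only the field names follow the iNaturalist JSON layout the class expects:

```python
# Toy reproduction of INatDataset's targeter construction.
data_catg = [{"name": "sparrow"}, {"name": "oak"}, {"name": "sparrow"}]
annotations = [{"category_id": 0}, {"category_id": 1}, {"category_id": 2}]

targeter = {}
indexer = 0
for elem in annotations:
    value = data_catg[int(elem["category_id"])]["name"]
    if value not in targeter:
        targeter[value] = indexer
        indexer += 1

# Category ids 0 and 2 share the value "sparrow", so they map to one label;
# this is why nb_classes = len(targeter), not len(data_catg).
print(targeter)  # {'sparrow': 0, 'oak': 1}
```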