Skip to content

Commit

Permalink
[Tools]: MAE Reconstructed Image Visualization (open-mmlab#376)
Browse files Browse the repository at this point in the history
* [Tools]: MAE Reconstructed Image Visualization]

* [Fix]: fix docstring and type hint

* [Fix]: fix docstring in MAE clsss

* [Fix]: fix docstring in MAE clsss

* [Fix]: fix type hint

* [Fix]: fix type hint and docstring

* [refactor]: refactor super init
  • Loading branch information
soonera authored and fangyixiao18 committed Jul 28, 2022
1 parent 074ae09 commit 3f530f0
Show file tree
Hide file tree
Showing 11 changed files with 638 additions and 24 deletions.
268 changes: 268 additions & 0 deletions demo/mae_visualization.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,268 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Copyright (c) OpenMMLab. All rights reserved.\n",
"\n",
"Copyright (c) Meta Platforms, Inc. and affiliates.\n",
"\n",
"Modified from https://colab.research.google.com/github/facebookresearch/mae/blob/main/demo/mae_visualize.ipynb\n",
"\n",
"## Masked Autoencoders: Visualization Demo\n",
"\n",
"This is a visualization demo using our pre-trained MAE models. No GPU is needed."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Prepare\n",
"Check environment. Install packages if in Colab."
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"import sys\n",
"import os\n",
"import requests\n",
"\n",
"import torch\n",
"import numpy as np\n",
"\n",
"import matplotlib.pyplot as plt\n",
"from PIL import Image\n",
"\n",
"from mmselfsup.models import build_algorithm\n",
"\n",
"# check whether run in Colab\n",
"if 'google.colab' in sys.modules:\n",
" print('Running in Colab.')\n",
" !pip install openmim\n",
" !mim install mmcv-full\n",
" !git clone https://github.com/open-mmlab/mmselfsup.git\n",
" %cd mmselfsup/\n",
" !pip install -e .\n",
" sys.path.append('./mmselfsup')\n",
" %cd demo\n",
"else:\n",
" sys.path.append('..')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Define utils"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# define the utils\n",
"\n",
"imagenet_mean = np.array([0.485, 0.456, 0.406])\n",
"imagenet_std = np.array([0.229, 0.224, 0.225])\n",
"\n",
"def show_image(image, title=''):\n",
" # image is [H, W, 3]\n",
" assert image.shape[2] == 3\n",
" image = torch.clip((image * imagenet_std + imagenet_mean) * 255, 0, 255).int()\n",
" plt.imshow(image)\n",
" plt.title(title, fontsize=16)\n",
" plt.axis('off')\n",
" return\n",
"\n",
"\n",
"def show_images(x, im_masked, y, im_paste):\n",
" # make the plt figure larger\n",
" plt.rcParams['figure.figsize'] = [24, 6]\n",
"\n",
" plt.subplot(1, 4, 1)\n",
" show_image(x, \"original\")\n",
"\n",
" plt.subplot(1, 4, 2)\n",
" show_image(im_masked, \"masked\")\n",
"\n",
" plt.subplot(1, 4, 3)\n",
" show_image(y, \"reconstruction\")\n",
"\n",
" plt.subplot(1, 4, 4)\n",
" show_image(im_paste, \"reconstruction + visible\")\n",
"\n",
" plt.show()\n",
"\n",
"\n",
"def post_process(x, y, mask):\n",
" x = torch.einsum('nchw->nhwc', x.cpu())\n",
" # masked image\n",
" im_masked = x * (1 - mask)\n",
" # MAE reconstruction pasted with visible patches\n",
" im_paste = x * (1 - mask) + y * mask\n",
" return x[0], im_masked[0], y[0], im_paste[0]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Load an image"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"# load an image\n",
"img_url = 'https://download.openmmlab.com/mmselfsup/mae/fox.jpg'\n",
"img_pil = Image.open(requests.get(img_url, stream=True).raw)\n",
"img = img_pil.resize((224, 224))\n",
"img = np.array(img) / 255.\n",
"\n",
"assert img.shape == (224, 224, 3)\n",
"\n",
"# normalize by ImageNet mean and std\n",
"img = img - imagenet_mean\n",
"img = img / imagenet_std\n",
"\n",
"plt.rcParams['figure.figsize'] = [5, 5]\n",
"show_image(torch.tensor(img))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Load a pre-trained MAE model"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"%%writefile ../configs/selfsup/mae/mae_visualization.py\n",
"model = dict(\n",
" type='MAE',\n",
" backbone=dict(type='MAEViT', arch='l', patch_size=16, mask_ratio=0.75),\n",
" neck=dict(\n",
" type='MAEPretrainDecoder',\n",
" patch_size=16,\n",
" in_chans=3,\n",
" embed_dim=1024,\n",
" decoder_embed_dim=512,\n",
" decoder_depth=8,\n",
" decoder_num_heads=16,\n",
" mlp_ratio=4.,\n",
" ),\n",
" head=dict(type='MAEPretrainHead', norm_pix=True, patch_size=16))\n",
"\n",
"img_norm_cfg = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])\n",
"# dataset summary\n",
"data = dict(\n",
" test=dict(\n",
" pipeline = [\n",
" dict(type='Resize', size=(224, 224)),\n",
" dict(type='ToTensor'),\n",
" dict(type='Normalize', **img_norm_cfg),]\n",
" ))"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"# This is an MAE model trained with pixels as targets for visualization (ViT-large, training mask ratio=0.75)\n",
"\n",
"# download checkpoint if not exist\n",
"# This ckpt is converted from https://dl.fbaipublicfiles.com/mae/visualize/mae_visualize_vit_large.pth\n",
"!wget -nc https://download.openmmlab.com/mmselfsup/mae/mae_visualize_vit_large.pth"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"from mmselfsup.apis import init_model\n",
"ckpt_path = \"mae_visualize_vit_large.pth\"\n",
"model = init_model('../configs/selfsup/mae/mae_visualization.py', ckpt_path, device='cpu')\n",
"print('Model loaded.')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Run MAE on the image"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [],
"source": [
"# make random mask reproducible (comment out to make it change)\n",
"torch.manual_seed(2)\n",
"print('MAE with pixel reconstruction:')\n",
"\n",
"from mmselfsup.apis import inference_model\n",
"\n",
"img_url = 'https://download.openmmlab.com/mmselfsup/mae/fox.jpg'\n",
"img = Image.open(requests.get(img_url, stream=True).raw)\n",
"img, (mask, pred) = inference_model(model, img)\n",
"x, im_masked, y, im_paste = post_process(img, pred, mask)\n",
"show_images(x, im_masked, y, im_paste)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.13"
},
"vscode": {
"interpreter": {
"hash": "3e4aeeccd14e965f43d0896afbaf8d71604e66b8605affbaa33ec76aa4083757"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}
21 changes: 21 additions & 0 deletions docs/en/get_started.md
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,27 @@ Arguments:
- `WORK_DIR`: the directory to save the results of visualization.
- `[optional arguments]`: for optional arguments, you can refer to [visualize_tsne.py](https://github.com/open-mmlab/mmselfsup/blob/master/tools/analysis_tools/visualize_tsne.py)

### MAE Visualization

We provide a tool to visualize the mask and reconstruction image of MAE model.

```shell
python tools/misc/mae_visualization.py ${IMG} ${CONFIG_FILE} ${CKPT_PATH} --device ${DEVICE}
```

参数:

- `IMG`: an image path used for visualization.
- `CONFIG_FILE`: config file for the pre-trained model.
- `CKPT_PATH`: the path of model's checkpoint.
- `DEVICE`: device used for inference.

An example:

```shell
python tools/misc/mae_visualization.py tests/data/color.jpg configs/selfsup/mae/mae_vit-base-p16_8xb512-coslr-400e_in1k.py mae_epoch_400.pth --device 'cuda:0'
```

### Reproducibility

If you want to make your performance exactly reproducible, please switch on `--deterministic` to train the final model to be published. Note that this flag will switch off `torch.backends.cudnn.benchmark` and slow down the training speed.
21 changes: 21 additions & 0 deletions docs/zh_cn/get_started.md
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,27 @@ python tools/analysis_tools/visualize_tsne.py ${CONFIG_FILE} --checkpoint ${CKPT
- `WORK_DIR`: 保存可视化结果的路径.
- `[optional arguments]`: 可选参数,具体可以参考 [visualize_tsne.py](../../tools/analysis_tools/visualize_tsne.py)

### MAE 可视化

我们提供了一个对 MAE 掩码效果和重建效果可视化可视化的方法:

```shell
python tools/misc/mae_visualization.py ${IMG} ${CONFIG_FILE} ${CKPT_PATH} --device ${DEVICE}
```

参数:

- `IMG`: 用于可视化的图片
- `CONFIG_FILE`: 训练预训练模型的参数配置文件.
- `CKPT_PATH`: 预训练模型的路径.
- `DEVICE`: 用于推理的设备.

示例:

```shell
python tools/misc/mae_visualization.py tests/data/color.jpg configs/selfsup/mae/mae_vit-base-p16_8xb512-coslr-400e_in1k.py mae_epoch_400.pth --device 'cuda:0'
```

### 可复现性

如果您想确保模型精度的可复现性,您可以设置 `--deterministic` 参数。但是,开启 `--deterministic` 意味着关闭 `torch.backends.cudnn.benchmark`, 所以会使模型的训练速度变慢。
6 changes: 5 additions & 1 deletion mmselfsup/apis/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .inference import inference_model, init_model
from .train import init_random_seed, set_random_seed, train_model

__all__ = ['init_random_seed', 'set_random_seed', 'train_model']
__all__ = [
'init_random_seed', 'inference_model', 'set_random_seed', 'train_model',
'init_model'
]
Loading

0 comments on commit 3f530f0

Please sign in to comment.