forked from open-mmlab/mmselfsup
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Tools]: MAE Reconstructed Image Visualization (open-mmlab#376)
* [Tools]: MAE Reconstructed Image Visualization] * [Fix]: fix docstring and type hint * [Fix]: fix docstring in MAE clsss * [Fix]: fix docstring in MAE clsss * [Fix]: fix type hint * [Fix]: fix type hint and docstring * [refactor]: refactor super init
- Loading branch information
1 parent
074ae09
commit 3f530f0
Showing
11 changed files
with
638 additions
and
24 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,268 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"Copyright (c) OpenMMLab. All rights reserved.\n", | ||
"\n", | ||
"Copyright (c) Meta Platforms, Inc. and affiliates.\n", | ||
"\n", | ||
"Modified from https://colab.research.google.com/github/facebookresearch/mae/blob/main/demo/mae_visualize.ipynb\n", | ||
"\n", | ||
"## Masked Autoencoders: Visualization Demo\n", | ||
"\n", | ||
"This is a visualization demo using our pre-trained MAE models. No GPU is needed." | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"### Prepare\n", | ||
"Check environment. Install packages if in Colab." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 2, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import sys\n", | ||
"import os\n", | ||
"import requests\n", | ||
"\n", | ||
"import torch\n", | ||
"import numpy as np\n", | ||
"\n", | ||
"import matplotlib.pyplot as plt\n", | ||
"from PIL import Image\n", | ||
"\n", | ||
"from mmselfsup.models import build_algorithm\n", | ||
"\n", | ||
"# check whether run in Colab\n", | ||
"if 'google.colab' in sys.modules:\n", | ||
" print('Running in Colab.')\n", | ||
" !pip install openmim\n", | ||
" !mim install mmcv-full\n", | ||
" !git clone https://github.com/open-mmlab/mmselfsup.git\n", | ||
" %cd mmselfsup/\n", | ||
" !pip install -e .\n", | ||
" sys.path.append('./mmselfsup')\n", | ||
" %cd demo\n", | ||
"else:\n", | ||
" sys.path.append('..')" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"### Define utils" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 3, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# define the utils\n", | ||
"\n", | ||
"imagenet_mean = np.array([0.485, 0.456, 0.406])\n", | ||
"imagenet_std = np.array([0.229, 0.224, 0.225])\n", | ||
"\n", | ||
"def show_image(image, title=''):\n", | ||
" # image is [H, W, 3]\n", | ||
" assert image.shape[2] == 3\n", | ||
" image = torch.clip((image * imagenet_std + imagenet_mean) * 255, 0, 255).int()\n", | ||
" plt.imshow(image)\n", | ||
" plt.title(title, fontsize=16)\n", | ||
" plt.axis('off')\n", | ||
" return\n", | ||
"\n", | ||
"\n", | ||
"def show_images(x, im_masked, y, im_paste):\n", | ||
" # make the plt figure larger\n", | ||
" plt.rcParams['figure.figsize'] = [24, 6]\n", | ||
"\n", | ||
" plt.subplot(1, 4, 1)\n", | ||
" show_image(x, \"original\")\n", | ||
"\n", | ||
" plt.subplot(1, 4, 2)\n", | ||
" show_image(im_masked, \"masked\")\n", | ||
"\n", | ||
" plt.subplot(1, 4, 3)\n", | ||
" show_image(y, \"reconstruction\")\n", | ||
"\n", | ||
" plt.subplot(1, 4, 4)\n", | ||
" show_image(im_paste, \"reconstruction + visible\")\n", | ||
"\n", | ||
" plt.show()\n", | ||
"\n", | ||
"\n", | ||
"def post_process(x, y, mask):\n", | ||
" x = torch.einsum('nchw->nhwc', x.cpu())\n", | ||
" # masked image\n", | ||
" im_masked = x * (1 - mask)\n", | ||
" # MAE reconstruction pasted with visible patches\n", | ||
" im_paste = x * (1 - mask) + y * mask\n", | ||
" return x[0], im_masked[0], y[0], im_paste[0]" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"### Load an image" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 16, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# load an image\n", | ||
"img_url = 'https://download.openmmlab.com/mmselfsup/mae/fox.jpg'\n", | ||
"img_pil = Image.open(requests.get(img_url, stream=True).raw)\n", | ||
"img = img_pil.resize((224, 224))\n", | ||
"img = np.array(img) / 255.\n", | ||
"\n", | ||
"assert img.shape == (224, 224, 3)\n", | ||
"\n", | ||
"# normalize by ImageNet mean and std\n", | ||
"img = img - imagenet_mean\n", | ||
"img = img / imagenet_std\n", | ||
"\n", | ||
"plt.rcParams['figure.figsize'] = [5, 5]\n", | ||
"show_image(torch.tensor(img))" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"### Load a pre-trained MAE model" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 15, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"%%writefile ../configs/selfsup/mae/mae_visualization.py\n", | ||
"model = dict(\n", | ||
" type='MAE',\n", | ||
" backbone=dict(type='MAEViT', arch='l', patch_size=16, mask_ratio=0.75),\n", | ||
" neck=dict(\n", | ||
" type='MAEPretrainDecoder',\n", | ||
" patch_size=16,\n", | ||
" in_chans=3,\n", | ||
" embed_dim=1024,\n", | ||
" decoder_embed_dim=512,\n", | ||
" decoder_depth=8,\n", | ||
" decoder_num_heads=16,\n", | ||
" mlp_ratio=4.,\n", | ||
" ),\n", | ||
" head=dict(type='MAEPretrainHead', norm_pix=True, patch_size=16))\n", | ||
"\n", | ||
"img_norm_cfg = dict(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])\n", | ||
"# dataset summary\n", | ||
"data = dict(\n", | ||
" test=dict(\n", | ||
" pipeline = [\n", | ||
" dict(type='Resize', size=(224, 224)),\n", | ||
" dict(type='ToTensor'),\n", | ||
" dict(type='Normalize', **img_norm_cfg),]\n", | ||
" ))" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 13, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# This is an MAE model trained with pixels as targets for visualization (ViT-large, training mask ratio=0.75)\n", | ||
"\n", | ||
"# download checkpoint if not exist\n", | ||
"# This ckpt is converted from https://dl.fbaipublicfiles.com/mae/visualize/mae_visualize_vit_large.pth\n", | ||
"!wget -nc https://download.openmmlab.com/mmselfsup/mae/mae_visualize_vit_large.pth" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 14, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from mmselfsup.apis import init_model\n", | ||
"ckpt_path = \"mae_visualize_vit_large.pth\"\n", | ||
"model = init_model('../configs/selfsup/mae/mae_visualization.py', ckpt_path, device='cpu')\n", | ||
"print('Model loaded.')" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"### Run MAE on the image" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 12, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# make random mask reproducible (comment out to make it change)\n", | ||
"torch.manual_seed(2)\n", | ||
"print('MAE with pixel reconstruction:')\n", | ||
"\n", | ||
"from mmselfsup.apis import inference_model\n", | ||
"\n", | ||
"img_url = 'https://download.openmmlab.com/mmselfsup/mae/fox.jpg'\n", | ||
"img = Image.open(requests.get(img_url, stream=True).raw)\n", | ||
"img, (mask, pred) = inference_model(model, img)\n", | ||
"x, im_masked, y, im_paste = post_process(img, pred, mask)\n", | ||
"show_images(x, im_masked, y, im_paste)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3 (ipykernel)", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.8.13" | ||
}, | ||
"vscode": { | ||
"interpreter": { | ||
"hash": "3e4aeeccd14e965f43d0896afbaf8d71604e66b8605affbaa33ec76aa4083757" | ||
} | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,8 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
from .inference import inference_model, init_model | ||
from .train import init_random_seed, set_random_seed, train_model | ||
|
||
__all__ = ['init_random_seed', 'set_random_seed', 'train_model'] | ||
__all__ = [ | ||
'init_random_seed', 'inference_model', 'set_random_seed', 'train_model', | ||
'init_model' | ||
] |
Oops, something went wrong.