Commit

beta release - code cleaned

jamessealesmith committed Apr 13, 2023
1 parent d220ee9 commit 22948d2
Showing 21 changed files with 62,415 additions and 4 deletions.
127 changes: 127 additions & 0 deletions .gitignore
@@ -0,0 +1,127 @@
# repo-specific stuff
data/
outputs/
*.pt
\#*#
.idea
*.sublime-*
*.pkl
.DS_Store
*.pth
*.png
.swp

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# dotenv
.env

# virtualenv
.venv
venv/
ENV/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
38 changes: 34 additions & 4 deletions README.md
@@ -1,7 +1,8 @@
-## (Coming April 2023) CODA-Prompt: COntinual Decomposed Attention-based Prompting for Rehearsal-Free Continual Learning
-PyTorch code for the CVPR 2023 paper (Coming April 2023):\
+## CODA-Prompt: COntinual Decomposed Attention-based Prompting for Rehearsal-Free Continual Learning
+PyTorch code for the CVPR 2023 paper:\
 **CODA-Prompt: COntinual Decomposed Attention-based Prompting for Rehearsal-Free Continual Learning**\
-**_[James Smith]_**, Leonid Karlinsky, Vyshnavi Gutta, Paola Cascante-Bonilla, Donghyun Kim, Assaf Arbelle, Rameswar Panda, Rogerio Feris, Zsolt Kira\
+**_[James Smith]_**, *Leonid Karlinsky, Vyshnavi Gutta, Paola Cascante-Bonilla*\
+*Donghyun Kim, Assaf Arbelle, Rameswar Panda, Rogerio Feris, Zsolt Kira* \
 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), 2023\
 [[arXiv]]

@@ -12,8 +13,37 @@ IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR), 2023\
 ## Abstract
 Computer vision models suffer from a phenomenon known as catastrophic forgetting when learning novel concepts from continuously shifting training data. Typical solutions for this continual learning problem require extensive rehearsal of previously seen data, which increases memory costs and may violate data privacy. Recently, the emergence of large-scale pre-trained vision transformer models has enabled prompting approaches as an alternative to data-rehearsal. These approaches rely on a key-query mechanism to generate prompts and have been found to be highly resistant to catastrophic forgetting in the well-established rehearsal-free continual learning setting. However, the key mechanism of these methods is not trained end-to-end with the task sequence. Our experiments show that this leads to a reduction in their plasticity, hence sacrificing new task accuracy, and inability to benefit from expanded parameter capacity. We instead propose to learn a set of prompt components which are assembled with input-conditioned weights to produce input-conditioned prompts, resulting in a novel attention-based end-to-end key-query scheme. Our experiments show that we outperform the current SOTA method DualPrompt on established benchmarks by as much as 5.4% in average accuracy. We also outperform the state of the art by as much as 6.6% accuracy on a continual learning benchmark which contains both class-incremental and domain-incremental task shifts, corresponding to many practical settings.
 
+## Setup
+* Install Anaconda: https://www.anaconda.com/distribution/
+* Set up a conda environment with Python 3.8, e.g. `conda create --name coda python=3.8`
+* `conda activate coda`
+* `sh install_requirements.sh`
+* <b>NOTE: this framework was tested using `torch == 2.0.0` but should work for previous versions</b>
+
+## Datasets
+* Create a folder `data/`
+* **CIFAR 100**: should automatically be downloaded
+* **ImageNet-R**: *coming soon*!
+* **DomainNet**: *coming soon*!
+
+## Training
+All commands should be run from the project root directory. **The scripts are set up for 4 GPUs** but can be modified for your hardware.
+
+```bash
+sh experiments/cifar100.sh
+```
+
+## Results
+Results will be saved in a folder named `outputs/`. To get the final average accuracy, retrieve the final number in the file `outputs/**/results-acc/global.yaml`.
+
+## Ready to create your next method?
+Create your new prompting method in `models/zoo.py`, which will require you to create a new class in `learners/prompt.py` as well. Hopefully, you can create your next method while only modifying these two files! I also recommend you develop with the ImageNet-R benchmark (*coming soon*) and use fewer epochs for faster results. **Cannot wait to see what method you develop!**
+
+## Acknowledgement
+This material is based upon work supported by the National Science Foundation under Grant No. 2239292.
+
 ## Citation
-If you found our work useful for your research, please cite our work:
+**If you found our work useful for your research, please cite our work**:
 
 @article{smith2022coda,
 title={CODA-Prompt: COntinual Decomposed Attention-based Prompting for Rehearsal-Free Continual Learning},
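To make the abstract's "prompt components assembled with input-conditioned weights" concrete, here is a minimal, hedged sketch of that kind of attention-based prompt assembly. It is illustrative only: the class name, tensor shapes, component count, and initialization are assumptions for exposition, not the repository's actual `models/zoo.py` implementation.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class PromptAssemblySketch(nn.Module):
    """Illustrative sketch: decomposed prompts combined with input-conditioned weights."""

    def __init__(self, num_components: int = 100, prompt_len: int = 8, emb_dim: int = 768):
        super().__init__()
        # A pool of learnable prompt components, each paired with a key and an
        # attention vector; all are trained end-to-end with the task sequence.
        self.components = nn.Parameter(torch.randn(num_components, prompt_len, emb_dim) * 0.02)
        self.keys = nn.Parameter(torch.randn(num_components, emb_dim) * 0.02)
        self.attention = nn.Parameter(torch.randn(num_components, emb_dim) * 0.02)

    def forward(self, query: torch.Tensor) -> torch.Tensor:
        # query: (batch, emb_dim), e.g. a frozen pre-trained ViT [CLS] feature of the input.
        attended = torch.einsum('bd,md->bmd', query, self.attention)             # feature-wise attention on the query
        weights = F.cosine_similarity(attended, self.keys.unsqueeze(0), dim=-1)  # (batch, M) component weights
        return torch.einsum('bm,mld->bld', weights, self.components)             # (batch, prompt_len, emb_dim)

# Hypothetical usage with a random query standing in for a ViT feature:
# prompts = PromptAssemblySketch()(torch.randn(4, 768))  # -> shape (4, 8, 768)
```

The assembled prompts would then be attached to the frozen transformer's attention layers; how and where that happens is defined by the repository's code, not this sketch.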
19 changes: 19 additions & 0 deletions configs/cifar-100_prompt.yaml
@@ -0,0 +1,19 @@
dataset: CIFAR100
first_split_size: 10
other_split_size: 10
schedule:
- 20
schedule_type: cosine
batch_size: 128
optimizer: Adam
lr: 0.001
momentum: 0.9
weight_decay: 0
model_type: zoo
model_name: vit_pt_imnet
max_task: -1
dataroot: data
workers: 4
validation: False
train_aug: True
rand_split: True
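For reference, a quick way to load and inspect this config — a minimal sketch assuming PyYAML is installed and the script is run from the project root; the repository's own entry point may parse these fields differently:

```python
import yaml  # PyYAML

# Load the experiment config shown above and print a few key fields.
with open("configs/cifar-100_prompt.yaml") as f:
    cfg = yaml.safe_load(f)

print(cfg["dataset"], cfg["model_name"])                                      # CIFAR100 vit_pt_imnet
print("epochs per task:", cfg["schedule"])                                    # [20]
print("task splits:", cfg["first_split_size"], cfg["other_split_size"])      # 10 10
print("optimizer:", cfg["optimizer"], "lr:", cfg["lr"], "batch size:", cfg["batch_size"])
```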
5 changes: 5 additions & 0 deletions dataloaders/__init__.py
@@ -0,0 +1,5 @@
from __future__ import absolute_import

from .dataloader import iCIFAR100, iCIFAR10, iIMAGENET_R

__all__ = ('iCIFAR100','iCIFAR10','iIMAGENET_R')
