Commit

Initial commit

symoon11 committed Oct 2, 2023
1 parent 3e322e9 commit 9514f88

Showing 26 changed files with 2,329 additions and 1 deletion.
129 changes: 129 additions & 0 deletions .gitignore
@@ -0,0 +1,129 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/
40 changes: 39 additions & 1 deletion README.md
@@ -1 +1,39 @@
# Achievement-Distillation
# Achievement Distillation

This repository contains the code for the paper [Discovering Hierarchical Achievements in Reinforcement Learning via Contrastive Learning](https://arxiv.org/abs/2307.03486), accepted to NeurIPS 2023.

## Installation

```
conda create --name ad-crafter python=3.10
conda activate ad-crafter
pip install --upgrade "setuptools==65.7.0" "wheel==0.38.4"
pip install -r requirements.txt
pip install -e .
```

## Usage

PPO (baseline)
```
python train.py --exp_name ppo --log_stats
```

PPO + Achievement Distillation (ours)
```
python train.py --exp_name ppo_ad --log_stats
```

## Citation

If you find this code useful, please cite this work.

```
@inproceedings{moon2023ad,
  title={Discovering Hierarchical Achievements in Reinforcement Learning via Contrastive Learning},
  author={Seungyong Moon and Junyoung Yeom and Bumsoo Park and Hyun Oh Song},
  booktitle={Neural Information Processing Systems},
  year={2023}
}
```
49 changes: 49 additions & 0 deletions achievement_distillation/action_head.py
@@ -0,0 +1,49 @@
import torch as th
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init


class CategoricalActionHead(nn.Module):
    def __init__(
        self,
        insize: int,
        num_actions: int,
        init_scale: float = 0.01,
    ):
        super().__init__()

        # Layer
        self.linear = nn.Linear(insize, num_actions)

        # Initialization
        init.orthogonal_(self.linear.weight, gain=init_scale)
        init.constant_(self.linear.bias, val=0.0)

    def forward(self, x: th.Tensor) -> th.Tensor:
        x = self.linear(x)
        logits = F.log_softmax(x, dim=-1)
        return logits

    def log_prob(self, logits: th.Tensor, actions: th.Tensor) -> th.Tensor:
        log_prob = th.gather(logits, dim=-1, index=actions)
        return log_prob

    def entropy(self, logits: th.Tensor) -> th.Tensor:
        probs = th.exp(logits)
        entropy = -th.sum(probs * logits, dim=-1, keepdim=True)
        return entropy

    def sample(self, logits: th.Tensor, deterministic: bool = False) -> th.Tensor:
        if deterministic:
            actions = th.argmax(logits, dim=-1, keepdim=True)
        else:
            u = th.rand_like(logits)
            u[u == 1.0] = 0.999
            gumbels = logits - th.log(-th.log(u))
            actions = th.argmax(gumbels, dim=-1, keepdim=True)
        return actions

    def kl_divergence(self, logits_q: th.Tensor, logits_p: th.Tensor) -> th.Tensor:
        kl = th.sum(th.exp(logits_q) * (logits_q - logits_p), dim=-1, keepdim=True)
        return kl
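
The head works entirely in log-probability space: `forward` returns `log_softmax` logits, and `sample` draws actions with the Gumbel-max trick (adding Gumbel noise to the log-probabilities and taking the argmax is equivalent to sampling from the categorical distribution; clamping `u` away from 1.0 avoids `log(0)`). A minimal usage sketch follows; the feature size and action count are made up purely for illustration, and the import path just mirrors the file layout in this commit.

```python
import torch as th

from achievement_distillation.action_head import CategoricalActionHead

# Hypothetical sizes, chosen only for this example
head = CategoricalActionHead(insize=512, num_actions=17)

x = th.randn(4, 512)                       # batch of latent features
logits = head(x)                           # log-probabilities, shape (4, 17)
actions = head.sample(logits)              # stochastic actions via Gumbel-max, shape (4, 1)
log_prob = head.log_prob(logits, actions)  # log-prob of the sampled actions, shape (4, 1)
entropy = head.entropy(logits)             # policy entropy, shape (4, 1)
```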
3 changes: 3 additions & 0 deletions achievement_distillation/algorithm/__init__.py
@@ -0,0 +1,3 @@
from .base import BaseAlgorithm
from .ppo import PPOAlgorithm
from .ppo_ad import PPOADAlgorithm
16 changes: 16 additions & 0 deletions achievement_distillation/algorithm/base.py
@@ -0,0 +1,16 @@
import abc
from typing import Dict

import torch as th

from achievement_distillation.model.base import BaseModel
from achievement_distillation.storage import RolloutStorage


class BaseAlgorithm(abc.ABC):
    def __init__(self, model: BaseModel):
        self.model = model

    @abc.abstractmethod
    def update(self, storage: RolloutStorage) -> Dict[str, th.Tensor]:
        raise NotImplementedError
78 changes: 78 additions & 0 deletions achievement_distillation/algorithm/ppo.py
@@ -0,0 +1,78 @@
import torch.nn as nn
import torch.optim as optim

from achievement_distillation.model.ppo import PPOModel
from achievement_distillation.algorithm.base import BaseAlgorithm
from achievement_distillation.storage import RolloutStorage


class PPOAlgorithm(BaseAlgorithm):
    def __init__(
        self,
        model: PPOModel,
        ppo_nepoch: int,
        ppo_nbatch: int,
        clip_param: float,
        vf_loss_coef: float,
        ent_coef: float,
        lr: float,
        max_grad_norm: float,
    ):
        super().__init__(model=model)

        # PPO params
        self.clip_param = clip_param
        self.ppo_nepoch = ppo_nepoch
        self.ppo_nbatch = ppo_nbatch
        self.vf_loss_coef = vf_loss_coef
        self.ent_coef = ent_coef
        self.max_grad_norm = max_grad_norm

        # Optimizer
        self.optimizer = optim.Adam(model.parameters(), lr=lr)

    def update(self, storage: RolloutStorage):
        # Set model to training mode
        self.model.train()

        # Run PPO
        pi_loss_epoch = 0
        vf_loss_epoch = 0
        entropy_epoch = 0
        nupdate = 0

        for _ in range(self.ppo_nepoch):
            # Get data loader
            data_loader = storage.get_data_loader(self.ppo_nbatch)

            for batch in data_loader:
                # Compute loss
                losses = self.model.compute_losses(**batch, clip_param=self.clip_param)
                pi_loss = losses["pi_loss"]
                vf_loss = losses["vf_loss"]
                entropy = losses["entropy"]
                loss = pi_loss + self.vf_loss_coef * vf_loss - self.ent_coef * entropy

                # Update parameters
                self.optimizer.zero_grad()
                loss.backward()
                nn.utils.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
                self.optimizer.step()

                # Update stats
                pi_loss_epoch += pi_loss.item()
                vf_loss_epoch += vf_loss.item()
                entropy_epoch += entropy.item()
                nupdate += 1

        # Compute average training stats
        pi_loss_epoch /= nupdate
        vf_loss_epoch /= nupdate
        entropy_epoch /= nupdate
        train_stats = {
            "pi_loss": pi_loss_epoch,
            "vf_loss": vf_loss_epoch,
            "entropy": entropy_epoch,
        }

        return train_stats
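
For reference, the total loss optimized here is `pi_loss + vf_loss_coef * vf_loss - ent_coef * entropy`, i.e. the policy loss plus a weighted value loss minus an entropy bonus. The policy term itself comes from `PPOModel.compute_losses`, which is not included in this excerpt; the sketch below shows only the standard PPO clipped surrogate as a reminder of what such a term typically computes, not the repository's actual implementation.

```python
import torch as th

def clipped_pi_loss(
    log_prob: th.Tensor,      # log-prob of the taken actions under the current policy
    old_log_prob: th.Tensor,  # log-prob of the same actions under the rollout policy
    adv: th.Tensor,           # advantage estimates
    clip_param: float,
) -> th.Tensor:
    # Standard PPO clipped surrogate; negated because it is minimized
    ratio = th.exp(log_prob - old_log_prob)
    surr1 = ratio * adv
    surr2 = th.clamp(ratio, 1.0 - clip_param, 1.0 + clip_param) * adv
    return -th.min(surr1, surr2).mean()
```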