clean code for ELo-SACv3
LostXine committed Oct 13, 2022
1 parent cf79b98 commit 2497bec
Showing 45 changed files with 205 additions and 124 deletions.
14 changes: 7 additions & 7 deletions README.md
@@ -1,10 +1,10 @@
# ELo-SAC: Evolving Losses + Soft Actor Critic

This repository is the official implementation of ELo-SACv2 as a part of our paper **Does Self-supervised Learning Really Improve Reinforcement Learning from Pixels?** ([openreview](https://openreview.net/forum?id=fVslVNBfjd8), [arxiv](https://arxiv.org/abs/2206.05266)) at NeurIPS 2022.
This repository is the official implementation of ELo-SACv3 as a part of our paper **Does Self-supervised Learning Really Improve Reinforcement Learning from Pixels?** ([openreview](https://openreview.net/forum?id=fVslVNBfjd8), [arxiv](https://arxiv.org/abs/2206.05266)) at NeurIPS 2022.

Our implementation is based on [SAC+AE](https://github.com/denisyarats/pytorch_sac_ae) by Denis Yarats and [CURL](https://github.com/MishaLaskin/curl) by Michael Laskin.

You may also want to check ELo-SACv3 at the main branch of this repository, and Atari experiments were done in a separate codebase (Check [ELo-Rainbow](https://github.com/LostXine/elo-rainbow)).
You may also want to check ELo-SACv2 at branch v2 of this repository, and Atari experiments were done in a separate codebase (Check [ELo-Rainbow](https://github.com/LostXine/elo-rainbow)).

## Installation

@@ -14,7 +14,7 @@ All of the dependencies are in the `conda_env.yml` file. They can be installed m
conda env create -f conda_env.yml
```

Change the server IP and port in the `search-server-ip` file if necessary.
Change the server IP and port in the `server-addr` file if necessary.

## Instructions

@@ -27,7 +27,7 @@ python3 search-server.py --port 61888 --timeout 24
which will start an HTTP server listening on the given port.
The server runs a PSO (Particle Swarm Optimization) algorithm and distributes tasks to the clients with GPUs.
Timeout means how many hours the server will wait for the client to report results before it assigns the same task to another client.
Our optimization status is stored at `save_server/evol_rl.npy` and will be automatically loaded.
Our optimization status is stored at `server/evolve.npy` and will be automatically loaded.
One could start a new search by assigning `--path` to a new file.
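
The server's PSO code is not part of this diff, so the following is only a minimal sketch of what one generic particle swarm update might look like, assuming higher scores are better. The function name `pso_step` and the coefficients `inertia`, `c1`, and `c2` are hypothetical and are not taken from `search-server.py`.

```python
import random


def pso_step(positions, velocities, personal_best, personal_best_score,
             scores, inertia=0.7, c1=1.5, c2=1.5, rng=random):
    """Apply one PSO update in place (maximizing) and return the global best.

    positions / velocities / personal_best are lists of equal-length lists;
    scores[i] is the score just reported for positions[i].
    """
    # Refresh each particle's personal best with the newly reported score.
    for i, score in enumerate(scores):
        if score > personal_best_score[i]:
            personal_best_score[i] = score
            personal_best[i] = list(positions[i])

    # The global best is the best personal best seen so far.
    best_idx = max(range(len(positions)), key=lambda i: personal_best_score[i])
    global_best = list(personal_best[best_idx])

    # Pull every particle toward its own best and toward the global best.
    for i in range(len(positions)):
        for d in range(len(positions[i])):
            r1, r2 = rng.random(), rng.random()
            velocities[i][d] = (
                inertia * velocities[i][d]
                + c1 * r1 * (personal_best[i][d] - positions[i][d])
                + c2 * r2 * (global_best[d] - positions[i][d])
            )
            positions[i][d] += velocities[i][d]
    return global_best
```

In a client/server search like the one described above, each new position would be handed out as a task, and the score would come back once a client finishes training with those parameters.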

To start the parameter search on clients, run `bash search.sh`.
@@ -36,14 +36,14 @@ When the training completes, the client will report the evaluation results to th

Run `bash check_status.sh` or `bash check_full_status.sh` to check the search status.

To stop the search, **stop** the current server and **restart** the search server with `--stop True` (see the `server-stop.sh` file).
To stop the search, **stop** the current server and **restart** the search server with `--stop True`.
All the clients will stop searching after finishing the current search.

To evaluate the optimal combination, run `bash eval-s09.sh` and it will start to train ELo-SAC agents in 6 DMControl environments with 10 random seeds.

See the `train.py` file for training hyper-parameters.
Check `train.py` for the training hyper-parameters. The optimal parameters reported in the paper are stored in `server/top8.json`.

## Contact

1. Issue
2. email: xiangli8@cs.stonybrook.edu
2. email: xiangli8@cs.stonybrook.edu
2 changes: 1 addition & 1 deletion check_status.sh
@@ -1,3 +1,3 @@
#!/bin/bash
url=`cat search-server-ip`
url=`cat server-addr`
curl -H 'Content-Type: application/json' $url"get_status"
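
For readers who prefer to query the server from Python, the same request can be sketched with the standard library alone. This assumes that `server-addr` holds a base URL such as `http://127.0.0.1:61888/` and that the endpoint names are `get_status` and `get_full_status`, as in the scripts of this commit. The helper names `endpoint` and `fetch_status` are hypothetical, and the response body is returned as raw text because its format is not shown in the diff.

```python
from urllib.parse import urljoin
from urllib.request import Request, urlopen


def endpoint(base_url, name):
    """Join the server base URL and an endpoint name, tolerating a missing '/'."""
    base = base_url.strip()
    if not base.endswith('/'):
        base += '/'
    return urljoin(base, name)


def fetch_status(base_url, name='get_status', timeout=10):
    """Send the same GET request as the curl one-liner and return the body text."""
    req = Request(endpoint(base_url, name),
                  headers={'Content-Type': 'application/json'})
    with urlopen(req, timeout=timeout) as resp:
        return resp.read().decode('utf-8')


if __name__ == '__main__':
    with open('server-addr', 'r') as f:
        print(fetch_status(f.readline()))
```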
37 changes: 32 additions & 5 deletions search-client.py → client.py
@@ -5,25 +5,45 @@
import time
import numpy as np
import traceback
import os

with open('search-server-ip', 'r') as f:
url = f.read().strip()
with open('server-addr', 'r') as f:
url = f.readline().strip()
print("URL: " + url)
p_data = None


def check_should_run():
return not os.path.exists('stop')

if __name__ == '__main__':
envs = [
('cartpole', 'swingup', 8),
('cheetah', 'run', 4),
('reacher', 'easy', 4),
('ball_in_cup', 'catch', 4),
('finger', 'spin', 2),
('walker', 'walk', 2)
]

torch.multiprocessing.set_start_method('spawn')
is_running = True
is_running = check_should_run()

while is_running:
while True:
is_running = check_should_run()
while is_running:
try:
r = requests.get(url + 'get_task')
if r.status_code == 200:
break
time.sleep(10)
is_running = check_should_run()
except ConnectionRefusedError:
print("Connection Refused Error, retry in 10s.")

if not is_running:
break

p_data = json.loads(r.text)
print(p_data)
if 'stop' in p_data:
@@ -37,7 +57,14 @@
print(f"Aug: {aug} | RL_Aug: {rl_aug} | Weights: ", weights)

try:
results = [main(weights, tid=p_data['tid'], aug=aug, rl_aug=rl_aug, seed=i, task_str=r.text) for i in p_data['seed']]
results = {}
for env in envs:
domain, task, action_repeat = env
results[domain + ' ' + task] = [main(weights,
domain=domain,
task=task,
action_repeat=action_repeat,
tid=p_data['tid'], aug=aug, rl_aug=rl_aug, seed=i, task_str=r.text) for i in p_data['seed']]
p_data['results'] = results
p_data['status'] = 0
except KeyboardInterrupt:
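
The polling loop in this hunk catches `ConnectionRefusedError`, but `requests` normally reports a refused connection as `requests.exceptions.ConnectionError`, and the `except` branch does not sleep before retrying. A possible tightening is sketched below. It is not the code in `client.py`: the names `poll_for_task`, `keep_going`, and `fetch` are hypothetical, and the HTTP call is injected so the loop can be exercised without a server. Catching `OSError` covers both cases, since `requests.exceptions.RequestException` derives from it.

```python
import os
import time


def should_run(stop_file='stop'):
    """Mirror check_should_run(): keep going until a file named 'stop' exists."""
    return not os.path.exists(stop_file)


def poll_for_task(fetch, keep_going=should_run, retry_s=10, sleep=time.sleep):
    """Call fetch() until it yields (200, body); return body, or None if stopped.

    fetch is any callable returning (status_code, text); it may raise OSError
    (which includes requests' exceptions) when the server is unreachable.
    """
    while keep_going():
        try:
            status, body = fetch()
            if status == 200:
                return body
        except OSError as err:
            print(f"Request failed ({err!r}), retry in {retry_s}s.")
        # Sleep after both a non-200 reply and a failed request.
        sleep(retry_s)
    return None


# Usage with requests (assumed to be installed, as in client.py):
#   import requests
#   def fetch():
#       r = requests.get(url + 'get_task')
#       return r.status_code, r.text
#   text = poll_for_task(fetch)
```

Returning `None` when the stop file appears lets the caller leave its outer loop, matching the `if not is_running: break` check in the hunk.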
5 changes: 3 additions & 2 deletions conda_env.yml
@@ -14,9 +14,10 @@ dependencies:
- pip:
- numpy
- termcolor
- git+https://github.com/deepmind/dm_control.git
- gym==0.25.1
- dm_control
- git+https://github.com/1nadequacy/dmc2gym.git
- tensorboard
- tb-nightly
- imageio
- imageio-ffmpeg
- scikit-image
7 changes: 7 additions & 0 deletions envs.sh
@@ -0,0 +1,7 @@

EGL_DEVICE_ID=$1 CUDA_VISIBLE_DEVICES=$1 python eval-client.py --domain_name cartpole --task_name swingup --action_repeat 8 --pre_transform_image_size 116 --work_dir ./eval --seed $2 --batch_size 512 --num_train_steps 12500 --save_tb --eval_freq 12500
EGL_DEVICE_ID=$1 CUDA_VISIBLE_DEVICES=$1 python eval-client.py --domain_name cheetah --task_name run --action_repeat 4 --pre_transform_image_size 116 --work_dir ./eval --seed $2 --batch_size 512 --num_train_steps 25000 --save_tb --eval_freq 12500
EGL_DEVICE_ID=$1 CUDA_VISIBLE_DEVICES=$1 python eval-client.py --domain_name reacher --task_name easy --action_repeat 4 --pre_transform_image_size 116 --work_dir ./eval --seed $2 --batch_size 512 --num_train_steps 25000 --save_tb --eval_freq 12500
EGL_DEVICE_ID=$1 CUDA_VISIBLE_DEVICES=$1 python eval-client.py --domain_name ball_in_cup --task_name catch --action_repeat 4 --pre_transform_image_size 116 --work_dir ./eval --seed $2 --batch_size 512 --num_train_steps 25000 --save_tb --eval_freq 12500
EGL_DEVICE_ID=$1 CUDA_VISIBLE_DEVICES=$1 python eval-client.py --domain_name finger --task_name spin --action_repeat 2 --pre_transform_image_size 116 --work_dir ./eval --seed $2 --batch_size 512 --num_train_steps 50000 --save_tb --eval_freq 12500
EGL_DEVICE_ID=$1 CUDA_VISIBLE_DEVICES=$1 python eval-client.py --domain_name walker --task_name walk --action_repeat 2 --pre_transform_image_size 116 --work_dir ./eval --seed $2 --batch_size 512 --num_train_steps 50000 --save_tb --eval_freq 12500
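
The six commands above differ only in domain, task, action repeat, and number of training steps, so they can also be generated from a small table. The sketch below reproduces the command lines of `envs.sh` for a given GPU index and seed; the helper name `eval_command` and the table name `ENVS` are hypothetical, while the flag values are copied from the script.

```python
import shlex

# (domain, task, action_repeat, num_train_steps), copied from envs.sh
ENVS = [
    ('cartpole', 'swingup', 8, 12500),
    ('cheetah', 'run', 4, 25000),
    ('reacher', 'easy', 4, 25000),
    ('ball_in_cup', 'catch', 4, 25000),
    ('finger', 'spin', 2, 50000),
    ('walker', 'walk', 2, 50000),
]


def eval_command(domain, task, action_repeat, steps, gpu, seed):
    """Build one evaluation command line in the same form as envs.sh."""
    env = f"EGL_DEVICE_ID={gpu} CUDA_VISIBLE_DEVICES={gpu}"
    args = [
        'python', 'eval-client.py',
        '--domain_name', domain,
        '--task_name', task,
        '--action_repeat', action_repeat,
        '--pre_transform_image_size', 116,
        '--work_dir', './eval',
        '--seed', seed,
        '--batch_size', 512,
        '--num_train_steps', steps,
        '--save_tb',
        '--eval_freq', 12500,
    ]
    return env + ' ' + ' '.join(shlex.quote(str(a)) for a in args)


if __name__ == '__main__':
    for domain, task, repeat, steps in ENVS:
        print(eval_command(domain, task, repeat, steps, gpu=0, seed=1))
```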
14 changes: 8 additions & 6 deletions eval-backbone512.sh
@@ -1,7 +1,9 @@

EGL_DEVICE_ID=$1 CUDA_VISIBLE_DEVICES=$1 python eval-client.py --domain_name cartpole --task_name swingup --action_repeat 8 --pre_transform_image_size 116 --work_dir ./save_eval --seed $2 --batch_size 512 --num_train_steps 12500 --save_tb --eval_freq 12500
EGL_DEVICE_ID=$1 CUDA_VISIBLE_DEVICES=$1 python eval-client.py --domain_name cheetah --task_name run --action_repeat 4 --pre_transform_image_size 116 --work_dir ./save_eval --seed $2 --batch_size 512 --num_train_steps 25000 --save_tb --eval_freq 12500
EGL_DEVICE_ID=$1 CUDA_VISIBLE_DEVICES=$1 python eval-client.py --domain_name reacher --task_name easy --action_repeat 4 --pre_transform_image_size 116 --work_dir ./save_eval --seed $2 --batch_size 512 --num_train_steps 25000 --save_tb --eval_freq 12500
EGL_DEVICE_ID=$1 CUDA_VISIBLE_DEVICES=$1 python eval-client.py --domain_name ball_in_cup --task_name catch --action_repeat 4 --pre_transform_image_size 116 --work_dir ./save_eval --seed $2 --batch_size 512 --num_train_steps 25000 --save_tb --eval_freq 12500
EGL_DEVICE_ID=$1 CUDA_VISIBLE_DEVICES=$1 python eval-client.py --domain_name finger --task_name spin --action_repeat 2 --pre_transform_image_size 116 --work_dir ./save_eval --seed $2 --batch_size 512 --num_train_steps 50000 --save_tb --eval_freq 12500
EGL_DEVICE_ID=$1 CUDA_VISIBLE_DEVICES=$1 python eval-client.py --domain_name walker --task_name walk --action_repeat 2 --pre_transform_image_size 116 --work_dir ./save_eval --seed $2 --batch_size 512 --num_train_steps 50000 --save_tb --eval_freq 12500
EGL_DEVICE_ID=$1 CUDA_VISIBLE_DEVICES=$1 python eval-client.py --domain_name cartpole --task_name swingup --action_repeat 8 --pre_transform_image_size 116 --work_dir ./eval --seed $2 --batch_size 512 --num_train_steps 12500 --save_tb --eval_freq 12500
EGL_DEVICE_ID=$1 CUDA_VISIBLE_DEVICES=$1 python eval-client.py --domain_name cheetah --task_name run --action_repeat 4 --pre_transform_image_size 116 --work_dir ./eval --seed $2 --batch_size 512 --num_train_steps 25000 --save_tb --eval_freq 12500
EGL_DEVICE_ID=$1 CUDA_VISIBLE_DEVICES=$1 python eval-client.py --domain_name reacher --task_name easy --action_repeat 4 --pre_transform_image_size 116 --work_dir ./eval --seed $2 --batch_size 512 --num_train_steps 25000 --save_tb --eval_freq 12500
EGL_DEVICE_ID=$1 CUDA_VISIBLE_DEVICES=$1 python eval-client.py --domain_name ball_in_cup --task_name catch --action_repeat 4 --pre_transform_image_size 116 --work_dir ./eval --seed $2 --batch_size 512 --num_train_steps 25000 --save_tb --eval_freq 12500
EGL_DEVICE_ID=$1 CUDA_VISIBLE_DEVICES=$1 python eval-client.py --domain_name finger --task_name spin --action_repeat 2 --pre_transform_image_size 116 --work_dir ./eval --seed $2 --batch_size 512 --num_train_steps 50000 --save_tb --eval_freq 12500
EGL_DEVICE_ID=$1 CUDA_VISIBLE_DEVICES=$1 python eval-client.py --domain_name walker --task_name walk --action_repeat 2 --pre_transform_image_size 116 --work_dir ./eval --seed $2 --batch_size 512 --num_train_steps 50000 --save_tb --eval_freq 12500
EGL_DEVICE_ID=$1 CUDA_VISIBLE_DEVICES=$1 python eval-client.py --domain_name hopper --task_name hop --action_repeat 2 --pre_transform_image_size 116 --work_dir ./eval --seed $2 --batch_size 512 --num_train_steps 50000 --save_tb --eval_freq 12500
EGL_DEVICE_ID=$1 CUDA_VISIBLE_DEVICES=$1 python eval-client.py --domain_name reacher --task_name hard --action_repeat 2 --pre_transform_image_size 116 --work_dir ./eval --seed $2 --batch_size 512 --num_train_steps 50000 --save_tb --eval_freq 12500
42 changes: 17 additions & 25 deletions eval-client.py
@@ -7,32 +7,24 @@
import traceback


with open('search-server-ip', 'r') as f:
url = f.read().strip()
p_data = None

if __name__ == '__main__':
torch.multiprocessing.set_start_method('spawn')
while True:
try:
r = requests.get(url + 'get_eval')
if r.status_code == 200:
break
time.sleep(10)
except ConnectionRefusedError:
print("Connection Refused Error, retry in 10s.")

p_data = json.loads(r.text)
print(p_data)

data = p_data['data']
aug = data[0]
rl_aug = data[1]
weights = data[2:]
print(f"Aug: {aug} | RL_Aug: {rl_aug} | Weights: ", weights)

main(weights, aug=aug, rl_aug=rl_aug, task_str=r.text)

print("Done")

for i in range(1, 21):
js_file = f'server/top{i}.json'
print(js_file)
with open(js_file, 'r') as f:
p_data = json.load(f)
print(p_data)

data = p_data['data']
aug = data[0]
rl_aug = data[1]
weights = data[2:]
print(f"Aug: {aug} | RL_Aug: {rl_aug} | Weights: ", weights)

main(weights, aug=aug, rl_aug=rl_aug, task_str=json.dumps(p_data), tid=i)

print(f"{js_file} done!")
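
The unpacking in the new loop (`data[0]`, `data[1]`, `data[2:]`) can be isolated into a small helper, which makes the expected layout of the `data` vector explicit. This is a sketch only: `split_search_vector` and `load_candidate` are hypothetical names, and the assumption that a valid vector has at least one weight after the two augmentation entries is mine, not something the commit states.

```python
import json


def split_search_vector(p_data):
    """Split p_data['data'] into (aug, rl_aug, weights) as eval-client.py does."""
    data = p_data['data']
    if len(data) < 3:
        # Assumed sanity check: two augmentation entries plus at least one weight.
        raise ValueError(f"expected at least 3 entries in 'data', got {len(data)}")
    return data[0], data[1], list(data[2:])


def load_candidate(path):
    """Read one stored candidate file (for example server/top8.json) and split it."""
    with open(path, 'r') as f:
        return split_search_vector(json.load(f))
```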


File renamed without changes.
2 changes: 1 addition & 1 deletion check_full_status.sh → full_status.sh
@@ -1,3 +1,3 @@
#!/bin/bash
url=`cat search-server-ip`
url=`cat server-addr`
curl -H 'Content-Type: application/json' $url"get_full_status"
Empty file removed save_server/.gitkeep
Empty file.
Binary file removed save_server/evol_rl.npy
Binary file not shown.
Empty file removed save_test/.gitkeep
Empty file.
1 change: 0 additions & 1 deletion search-server-ip

This file was deleted.

2 changes: 1 addition & 1 deletion search.sh
@@ -6,5 +6,5 @@ fi

echo "Use GPU: $1"

EGL_DEVICE_ID=$1 CUDA_VISIBLE_DEVICES=$1 python3 search-client.py --domain_name cheetah --task_name run --num_train_steps 25000 --pre_transform_image_size 116 --work_dir ./save_search --batch_size 128 --save_model
EGL_DEVICE_ID=$1 CUDA_VISIBLE_DEVICES=$1 python3 client.py --num_train_steps 25000 --pre_transform_image_size 116 --work_dir ../main_results/search --batch_size 64

1 change: 1 addition & 0 deletions server-addr
@@ -0,0 +1 @@
http://127.0.0.1:61888/
2 changes: 0 additions & 2 deletions server-stop.sh

This file was deleted.
