initial commit
Pekka committed Apr 16, 2017
0 parents commit 88ddeca
Showing 11 changed files with 709 additions and 0 deletions.
3 changes: 3 additions & 0 deletions .gitignore
@@ -0,0 +1,3 @@
.idea*
.DS_store
*.pyc
42 changes: 42 additions & 0 deletions README.md
@@ -0,0 +1,42 @@
![](https://media.giphy.com/media/3og0IEKu84Ros9izyU/giphy.gif)

This project implements the DQN reinforcement learning agent from
[Human-level control through deep reinforcement
learning](http://www.davidqiu.com:8888/research/nature14236.pdf)

(See also David Silver's RL course, [lecture 7](https://www.youtube.com/watch?v=UoPei5o4fps).)

The agent is applied to OpenAI Gym's [2d-car-racing environment](https://gym.openai.com/envs/CarRacing-v0).

The structure of the Q-network differs from the original paper.
In particular, the network here is much smaller and can easily be trained without a GPU.
(Other network structures are easy to specify as well.)

The agent learns to drive the car from pixels in a few hours and doesn't need any hand-crafted features.
There are some minor environment-specific tweaks for car racing, but the base agent doesn't know anything about car racing.
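
For orientation, the core of the DQN update described in the paper is a Bellman target computed from a separate target network. The following is a minimal NumPy sketch of that target, not the code from this repo: the function name `dqn_targets` and its arguments are hypothetical, and only `gamma=0.95` is taken from the runner's default settings.

```python
import numpy as np


def dqn_targets(rewards, next_q_values, dones, gamma=0.95):
    """Return r + gamma * max_a' Q_target(s', a') for a batch of transitions.

    rewards:       shape (batch,), reward observed for each transition
    next_q_values: shape (batch, n_actions), target-network outputs for s'
    dones:         shape (batch,), True where the episode ended at s'
    """
    rewards = np.asarray(rewards, dtype=np.float64)
    next_q_values = np.asarray(next_q_values, dtype=np.float64)
    dones = np.asarray(dones, dtype=bool)

    # Greedy value of the next state according to the target network.
    best_next = next_q_values.max(axis=1)

    # Terminal transitions do not bootstrap from the next state.
    return rewards + gamma * best_next * (~dones)
```

The online network is then regressed toward these targets for the actions that were actually taken. The `network_update_freq` setting in the runner presumably controls how often the target network is refreshed, but that reading is an assumption.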

#### pre-trained agent
The checkpoint provided in the repo was trained with the default parameters
specified in the runner/agent for roughly 150,000 playing steps.

The training took about 5 hours on a CPU.
This agent is playing in the above gif and in this video:
https://youtu.be/CVZQOAlQib0
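
The runner's defaults include `min_epsilon=0.1` and `epsilon_decay_steps=int(1e5)`. Assuming exploration is annealed linearly from 1.0, as in the paper (the agent's actual schedule may differ), the exploration rate over training could be sketched as follows. The function `linear_epsilon` and the `start_epsilon` parameter are hypothetical names used only for illustration.

```python
def linear_epsilon(step, min_epsilon=0.1, epsilon_decay_steps=int(1e5),
                   start_epsilon=1.0):
    """Linearly anneal epsilon from start_epsilon to min_epsilon.

    start_epsilon=1.0 and the linear shape are assumptions; min_epsilon and
    epsilon_decay_steps are the defaults from the runner's model_config.
    """
    # Fraction of the decay period that has elapsed, capped at 1.
    fraction = min(step / epsilon_decay_steps, 1.0)
    return start_epsilon + fraction * (min_epsilon - start_epsilon)
```

Under this assumption, a run of roughly 150,000 steps would spend about the last third of its training at the minimum exploration rate.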

The agent sometimes cuts corners and makes occasional mistakes,
but otherwise it can drive cleanly for minutes at a time.

#### Running instructions
Clone the repo and run `car_runner.py`.
The settings are specified at the beginning of the runner.

You can either train from scratch or load the existing checkpoint
from this repo and see the agent driving somewhat properly right away.
You can also continue training from the provided checkpoint.
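
As a concrete example, the settings block at the top of `car_runner.py` contains a commented-out variant for playing from the provided checkpoint. Uncommented, it looks roughly like this (values copied from the runner; the comments are added for illustration):

```python
# Play with the provided checkpoint without any further training.
load_checkpoint = True
checkpoint_path = "data/checkpoint01"

# 0 skips training entirely; a larger value would first continue
# training from the loaded checkpoint for that many episodes.
train_episodes = 0
```

To train from scratch instead, the runner's defaults set `load_checkpoint = False` and point `checkpoint_path` at a directory that does not exist yet; the runner asserts that the path is unused in that mode.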

#### Dependencies
- Python 3.5 (will not work with Python 2)
- OpenAI Gym (the car-racing environment)
- TensorFlow 1.0.0
- numpy
- scikit-image
142 changes: 142 additions & 0 deletions car_runner.py
@@ -0,0 +1,142 @@
# get_ipython().magic(u'load_ext autoreload')
# get_ipython().magic(u'autoreload 2')

"""
Python 3.5, tensorflow 1.0.0
Trains first for train_episodes amount of episodes
and then starts playing with the best known policy with no exploration
Optionally save checkpoints to checkpoint_path (set checkpoint_path=None to not save anything)
Experience history is never saved
Training and playing can be early stopped by giving input (pressing enter in console)
"""

from dqn.agent import CarRacingDQN
import os
import tensorflow as tf
import gym
import _thread
import re
import sys

# SETTINGS

# to start training from scratch:
load_checkpoint = False
checkpoint_path = "data/checkpoint02"
train_episodes = float("inf")
save_freq_episodes = 400

# To play from existing checkpoint without any training:
# load_checkpoint = True
# checkpoint_path = "data/checkpoint01"
# train_episodes = 0 #or just give higher value to train the existing checkpoint more

model_config = dict(
    min_epsilon=0.1,
    max_negative_rewards=12,
    min_experience_size=int(1e3),
    num_frame_stack=3,
    frame_skip=3,
    train_freq=4,
    batchsize=64,
    epsilon_decay_steps=int(1e5),
    network_update_freq=int(1e3),
    experience_capacity=int(4e4),
    gamma=0.95
)

print(model_config)
########

env_name = "CarRacing-v0"
env = gym.make(env_name)

# tf.reset_default_graph()
dqn_agent = CarRacingDQN(env=env, **model_config)
dqn_agent.build_graph()
sess = tf.InteractiveSession()
dqn_agent.session = sess

saver = tf.train.Saver(max_to_keep=100)

if load_checkpoint:
    print("loading the latest checkpoint from %s" % checkpoint_path)
    ckpt = tf.train.get_checkpoint_state(checkpoint_path)
    assert ckpt, "checkpoint path %s not found" % checkpoint_path
    # the step count is the numeric suffix of the checkpoint name, e.g. m.ckpt-152156
    global_counter = int(re.findall(r"-(\d+)$", ckpt.model_checkpoint_path)[0])
    saver.restore(sess, ckpt.model_checkpoint_path)
    dqn_agent.global_counter = global_counter
else:
    if checkpoint_path is not None:
        assert not os.path.exists(checkpoint_path), \
            "checkpoint path already exists but load_checkpoint is false"

    tf.global_variables_initializer().run()


def save_checkpoint():
    if not os.path.exists(checkpoint_path):
        os.makedirs(checkpoint_path)
    p = os.path.join(checkpoint_path, "m.ckpt")
    saver.save(sess, p, dqn_agent.global_counter)
    print("saved to %s - %d" % (p, dqn_agent.global_counter))


def one_episode():
    reward, frames = dqn_agent.play_episode()
    print("episode: %d, reward: %f, length: %d, total steps: %d" %
          (dqn_agent.episode_counter, reward, frames, dqn_agent.global_counter))

    save_cond = (
        dqn_agent.episode_counter % save_freq_episodes == 0
        and checkpoint_path is not None
        and dqn_agent.do_training
    )
    if save_cond:
        save_checkpoint()


def input_thread(stop_signal):
    # blocks until enter is pressed, then signals the main loop to stop
    input("...enter to stop after current episode\n")
    stop_signal.append("OK")


def main_loop():
    """
    Plays episodes one after another
    until we get input to stop
    or, when training, until train_episodes is exceeded
    """
    stop_signal = []
    _thread.start_new_thread(input_thread, (stop_signal,))
    while True:
        if stop_signal:
            break
        if dqn_agent.do_training and dqn_agent.episode_counter > train_episodes:
            break
        one_episode()

    print("done")


if train_episodes > 0:
    print("now training... you can early stop with enter...")
    print("##########")
    sys.stdout.flush()
    main_loop()
    # checkpoint_path=None means nothing should be saved
    if checkpoint_path is not None:
        save_checkpoint()
    print("ok training done")

sys.stdout.flush()

dqn_agent.max_neg_rewards = 100
dqn_agent.do_training = False

print("now just playing...")
print("##########")
sys.stdout.flush()
main_loop()

print("That's it. Good bye")
2 changes: 2 additions & 0 deletions data/checkpoint01/checkpoint
@@ -0,0 +1,2 @@
model_checkpoint_path: "m.ckpt-152156"
all_model_checkpoint_paths: "m.ckpt-152156"
Binary file not shown.
Binary file added data/checkpoint01/m.ckpt-152156.index
Binary file not shown.
Binary file added data/checkpoint01/m.ckpt-152156.meta
Binary file not shown.
Empty file added dqn/__init__.py
Empty file.