
Commit

new core added
carloderamo committed Aug 1, 2018
1 parent f0f7ab0 commit f4f5cc3
Showing 5 changed files with 158 additions and 137 deletions.
79 changes: 0 additions & 79 deletions atari.py

This file was deleted.

104 changes: 104 additions & 0 deletions core.py
@@ -0,0 +1,104 @@
from tqdm import tqdm

import numpy as np


class Core(object):
    def __init__(self, agent, mdp, callbacks=None):
        self.agent = agent
        self.mdp = mdp
        self._n_mdp = len(self.mdp)
        self.callbacks = callbacks if callbacks is not None else list()

        self._state = [None for _ in range(self._n_mdp)]

        self._total_steps_counter = 0
        self._current_steps_counter = 0
        self._episode_steps = [None for _ in range(self._n_mdp)]
        self._n_steps_per_fit = None

    def learn(self, n_steps=None, n_steps_per_fit=None, render=False,
              quiet=False):
        self._n_steps_per_fit = n_steps_per_fit

        fit_condition = \
            lambda: self._current_steps_counter >= self._n_steps_per_fit

        self._run(n_steps, fit_condition, render, quiet)

    def evaluate(self, n_steps=None, render=False,
                 quiet=False):
        fit_condition = lambda: False

        return self._run(n_steps, fit_condition, render, quiet)

    def _run(self, n_steps, fit_condition, render, quiet):
        move_condition = lambda: self._total_steps_counter < n_steps

        steps_progress_bar = tqdm(total=n_steps,
                                  dynamic_ncols=True, disable=quiet,
                                  leave=False)

        return self._run_impl(move_condition, fit_condition, steps_progress_bar,
                              render)

    def _run_impl(self, move_condition, fit_condition, steps_progress_bar,
                  render):
        self._total_steps_counter = 0
        self._current_steps_counter = 0

        dataset = list()
        last = [True] * self._n_mdp
        while move_condition():
            for i in range(self._n_mdp):
                if last[i]:
                    self.reset(i)

                sample = self._step(i, render)
                dataset.append(sample)

                last[i] = sample[-1]

            self._total_steps_counter += 1
            self._current_steps_counter += 1
            steps_progress_bar.update(1)

            if fit_condition():
                self.agent.fit(dataset)
                self._current_episodes_counter = 0
                self._current_steps_counter = 0

                for c in self.callbacks:
                    callback_pars = dict(dataset=dataset)
                    c(**callback_pars)

                dataset = list()

        self.agent.stop()
        for i in range(self._n_mdp):
            self.mdp[i].stop()

        return dataset

    def _step(self, i, render):
        action = self.agent.draw_action([i, self._state[i]])
        next_state, reward, absorbing, _ = self.mdp[i].step(action)

        self._episode_steps[i] += 1

        if render:
            self.mdp[i].render()

        last = not(
            self._episode_steps[i] < self.mdp[i].info.horizon and not absorbing)

        state = self._state[i]
        self._state[i] = np.array(next_state)  # Copy for safety reasons

        return [i, state], action, reward, [i, next_state], absorbing, last

    def reset(self, i):
        self._state[i] = self.mdp[i].reset()
        self.agent.episode_start(i)
        self.agent.next_action = None
        self._episode_steps[i] = 0
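
For context, a minimal usage sketch (not part of this commit) showing how the new multi-MDP Core could be driven. The ToyAgent and ToyMDP classes below are hypothetical stand-ins that only implement the interface Core relies on (agent.draw_action/fit/stop/episode_start/next_action; mdp.reset/step/render/stop and mdp.info.horizon), assuming the Core class above is importable:

    import numpy as np

    class _Info:
        horizon = 10  # episode length limit used by Core._step

    class ToyMDP:
        # Hypothetical environment stand-in exposing the interface Core expects.
        info = _Info()

        def reset(self):
            return np.zeros(2)

        def step(self, action):
            # Return (next_state, reward, absorbing, info), as Core._step unpacks.
            return np.random.randn(2), 0., False, {}

        def render(self):
            pass

        def stop(self):
            pass

    class ToyAgent:
        # Hypothetical agent stand-in; draw_action receives [mdp_index, state].
        next_action = None

        def draw_action(self, state):
            idx, obs = state
            return np.array([0])

        def fit(self, dataset):
            pass

        def episode_start(self, idx):
            pass

        def stop(self):
            pass

    core = Core(ToyAgent(), [ToyMDP(), ToyMDP()])
    core.learn(n_steps=100, n_steps_per_fit=10, quiet=True)
    dataset = core.evaluate(n_steps=20, quiet=True)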
8 changes: 4 additions & 4 deletions policy.py
@@ -40,7 +40,7 @@ def __call__(self, *args):
         return probs
 
     def draw_action(self, state):
-        idx = np.asscalar(state[0])
+        idx = state[0]
         state = state[1]
         if not np.random.uniform() < self._epsilons[idx](state):
             q = self._approximator.predict(
@@ -65,7 +65,7 @@ def set_epsilon(self, epsilon):
             self._epsilons[i] = epsilon
 
     def update(self, state):
-        idx = np.asscalar(state[0])
+        idx = state[0]
         self._epsilons[idx].update(state)


@@ -108,7 +108,7 @@ def __call__(self, *args):
         return probs
 
     def draw_action(self, state):
-        idx = np.asscalar(state[0])
+        idx = state[0]
         state = state[1]
         if not np.random.uniform() < self._epsilons[idx](state):
             q = self._approximator[idx].predict(state)
@@ -132,5 +132,5 @@ def set_epsilon(self, epsilon):
             self._epsilons[i] = epsilon
 
     def update(self, state):
-        idx = np.asscalar(state[0])
+        idx = state[0]
         self._epsilons[idx].update(state)
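
The policy changes above rely on the composite state convention introduced by the new core: the first element is a plain integer game index (so np.asscalar is no longer needed) and the second is the observation. A small illustrative snippet, using a hypothetical per-game epsilon table in place of self._epsilons:

    import numpy as np

    # Hypothetical per-game exploration schedule, indexed like self._epsilons.
    epsilons = {0: lambda s: 0.1, 1: lambda s: 0.05}

    state = [1, np.array([0.3, -0.7])]  # [game index, observation], as built by Core._step
    idx = state[0]                      # already an int, no np.asscalar needed
    obs = state[1]
    explore = np.random.uniform() < epsilons[idx](obs)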
23 changes: 12 additions & 11 deletions shared/dqn.py
@@ -39,15 +39,15 @@ def __init__(self, approximator, policy, mdp_info, batch_size,
         self._entropy_coeff = entropy_coeff
 
         self._replay_memory = [
-            ReplayMemory(mdp_info, initial_replay_size, max_replay_size,
-                         history_length, dtype) for _ in range(self._n_games)
+            ReplayMemory(mdp_info[i], initial_replay_size, max_replay_size,
+                         history_length, dtype) for i in range(self._n_games)
         ]
         self._buffer = [
             Buffer(history_length, dtype) for _ in range(self._n_games)
         ]
 
         self._n_updates = 0
-        self._episode_steps = 0
+        self._episode_steps = [0 for _ in range(self._n_games)]
         self._no_op_actions = None
 
         apprx_params_train = deepcopy(approximator_params)
@@ -60,7 +60,7 @@ def __init__(self, approximator, policy, mdp_info, batch_size,
         self.target_approximator.model.set_weights(
             self.approximator.model.get_weights())
 
-        super().__init__(policy, mdp_info)
+        super().__init__(policy, mdp_info[np.argmax(self._n_action_per_head)])
 
         n_samples = self._batch_size * self._n_games
         self._state_idxs = np.zeros(n_samples, dtype=np.int)
@@ -183,25 +183,26 @@ def _next_q(self):
         return out_q
 
     def draw_action(self, state):
-        self._buffer[np.asscalar(state[0])].add(state[1])
+        idx = state[0]
+        self._buffer[idx].add(state[1])
 
-        if self._episode_steps < self._no_op_actions:
+        if self._episode_steps[idx] < self._no_op_actions:
             action = np.array([self._no_op_action_value])
             self.policy.update(state)
         else:
-            extended_state = self._buffer[np.asscalar(state[0])].get()
+            extended_state = self._buffer[idx].get()
 
-            extended_state = np.array([state[0], extended_state])
+            extended_state = [idx, np.array([extended_state])]
             action = super(DQN, self).draw_action(extended_state)
 
-        self._episode_steps += 1
+        self._episode_steps[idx] += 1
 
         return action
 
-    def episode_start(self):
+    def episode_start(self, idx):
         if self._max_no_op_actions == 0:
             self._no_op_actions = 0
         else:
             self._no_op_actions = np.random.randint(
                 self._history_length, self._max_no_op_actions + 1)
-        self._episode_steps = 0
+        self._episode_steps[idx] = 0
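
As a rough illustration of the new per-game bookkeeping in this hunk (the standalone names below are hypothetical, not from the repository): each game keeps its own episode-step counter and no-op phase, and the stacked frames are repackaged with the game index before being handed to the policy.

    import numpy as np

    n_games = 2
    episode_steps = [0 for _ in range(n_games)]  # one counter per game, as in the diff

    def episode_start(idx, history_length=4, max_no_op_actions=30):
        # Sample a per-episode no-op phase and reset only this game's counter.
        no_op_actions = np.random.randint(history_length, max_no_op_actions + 1)
        episode_steps[idx] = 0
        return no_op_actions

    idx = 1
    no_op = episode_start(idx)
    frames = np.zeros((4, 84, 84))              # stacked history from the per-game buffer
    extended_state = [idx, np.array([frames])]  # same packaging as draw_action above
    episode_steps[idx] += 1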