
Commit

a
carloderamo committed Jul 19, 2019
1 parent c3505aa commit 022e349
Showing 3 changed files with 51 additions and 35 deletions.
9 changes: 8 additions & 1 deletion core.py
@@ -17,6 +17,8 @@ def __init__(self, agent, mdp, callbacks=None):
         self._episode_steps = [None for _ in range(self._n_mdp)]
         self._n_steps_per_fit = None
 
+        self.prioritized = False
+
     def learn(self, n_steps=None, n_steps_per_fit=None, render=False,
               quiet=False):
         self._n_steps_per_fit = n_steps_per_fit
@@ -50,7 +52,12 @@ def _run_impl(self, move_condition, fit_condition, steps_progress_bar,
         dataset = list()
         last = [True] * self._n_mdp
         while move_condition():
-            for i in range(self._n_mdp):
+            if self.prioritized:
+                p = self.agent.replay_memory_priorities
+                mdps = [np.random.choice(self._n_mdp, p=p)]
+            else:
+                mdps = np.arange(self._n_mdp)
+            for i in mdps:
                 if last[i]:
                     self.reset(i)
 
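For reference, a minimal standalone sketch of the environment-selection logic this core.py change introduces: when prioritized is set, a single MDP index is drawn per iteration according to the agent's replay_memory_priorities; otherwise every MDP is stepped in turn. The three-task priority vector below is invented for illustration.

import numpy as np

n_mdp = 3
prioritized = True
# Invented priorities; in the commit they come from
# agent.replay_memory_priorities (normalized per-game gradient magnitudes).
replay_memory_priorities = np.array([0.5, 0.3, 0.2])

if prioritized:
    # Draw a single task index, biased towards high-priority tasks.
    mdps = [np.random.choice(n_mdp, p=replay_memory_priorities)]
else:
    # Step every task in turn, as before this change.
    mdps = np.arange(n_mdp)

for i in mdps:
    print('stepping MDP', i)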
68 changes: 36 additions & 32 deletions dqn/dqn.py
@@ -20,7 +20,7 @@ def __init__(self, approximator, policy, mdp_info, batch_size,
                  history_length=4, n_input_per_mdp=None, replay_memory=None,
                  target_update_frequency=2500, fit_params=None,
                  approximator_params=None, n_games=1, clip_reward=True,
-                 dtype=np.uint8):
+                 sampling_prior=False, dtype=np.uint8):
         self._fit_params = dict() if fit_params is None else fit_params
 
         self._batch_size = batch_size
Expand Down Expand Up @@ -80,7 +80,9 @@ def __init__(self, approximator, policy, mdp_info, batch_size,
self._absorbing = np.zeros(n_samples)
self._idxs = np.zeros(n_samples, dtype=np.int)
self._is_weight = np.zeros(n_samples)
self._replay_memory_priorities = np.ones(self._n_games) / self._n_games
self.replay_memory_priorities = np.ones(self._n_games) / self._n_games

self._sampling_prior = sampling_prior

def fit(self, dataset):
self._fit(dataset)
@@ -127,6 +129,19 @@ def _fit_standard(self, dataset):
             q_next = self._next_q()
             q = reward + q_next
 
+            if self._sampling_prior:
+                for i in range(self._n_games):
+                    idxs = np.argwhere(self._state_idxs == i).ravel()
+                    self.approximator.fit(
+                        self._state[idxs], self._action[idxs], q[idxs],
+                        idx=self._state_idxs[idxs],
+                        er_idx=i,
+                        params=self.approximator.model.network.get_shared_weights_tensor(),
+                        **self._fit_params
+                    )
+                grads = self.approximator.model.grads
+                self.replay_memory_priorities = grads / grads.sum()
+
             self.approximator.fit(self._state, self._action, q,
                                   idx=self._state_idxs, **self._fit_params)
 
@@ -146,36 +161,24 @@ def _fit_prioritized(self, dataset):
         fit_condition = np.all([rm.initialized for rm in self._replay_memory])
 
         if fit_condition:
-            replay_memory_idxs = np.random.choice(
-                self._n_games, size=self._n_games * self._batch_size,
-                p=self._replay_memory_priorities
-            )
-            counts = np.zeros(self._n_games, dtype=np.int)
-            idx, c = np.unique(replay_memory_idxs, return_counts=True)
-            counts[idx] = c
-            start = 0
-            for i, c in enumerate(counts):
-                if c == 0:
-                    continue
+            for i in range(len(self._replay_memory)):
                 game_state, game_action, game_reward, game_next_state,\
                     game_absorbing, _, game_idxs, game_is_weight =\
-                    self._replay_memory[i].get(counts[i])
+                    self._replay_memory[i].get(self._batch_size)
 
-                stop = start + counts[i]
-                diff = stop - start
+                start = self._batch_size * i
+                stop = start + self._batch_size
 
-                self._state_idxs[start:stop] = np.ones(diff) * i
+                self._state_idxs[start:stop] = np.ones(self._batch_size) * i
                 self._state[start:stop, :self._n_input_per_mdp[i][0]] = game_state
                 self._action[start:stop] = game_action
                 self._reward[start:stop] = game_reward
-                self._next_state_idxs[start:stop] = np.ones(diff) * i
+                self._next_state_idxs[start:stop] = np.ones(self._batch_size) * i
                 self._next_state[start:stop, :self._n_input_per_mdp[i][0]] = game_next_state
                 self._absorbing[start:stop] = game_absorbing
                 self._idxs[start:stop] = game_idxs
                 self._is_weight[start:stop] = game_is_weight
 
-                start = stop
-
             if self._clip_reward:
                 reward = np.clip(self._reward, -1, 1)
             else:
@@ -187,18 +190,19 @@ def _fit_prioritized(self, dataset):
                                                   idx=self._state_idxs)
             td_error = q - q_current
 
-            for i in range(self._n_games):
-                idxs = np.argwhere(self._state_idxs == i).ravel()
-                self.approximator.fit(
-                    self._state[idxs], self._action[idxs], q[idxs],
-                    weights=self._is_weight[idxs],
-                    idx=self._state_idxs[idxs],
-                    er_idx=i,
-                    params=self.approximator.model.network.get_shared_weights_tensor(),
-                    **self._fit_params
-                )
-            grads = self.approximator.model.grads
-            self._replay_memory_priorities = grads / grads.sum()
+            if self._sampling_prior:
+                for i in range(self._n_games):
+                    idxs = np.argwhere(self._state_idxs == i).ravel()
+                    self.approximator.fit(
+                        self._state[idxs], self._action[idxs], q[idxs],
+                        weights=self._is_weight[idxs],
+                        idx=self._state_idxs[idxs],
+                        er_idx=i,
+                        params=self.approximator.model.network.get_shared_weights_tensor(),
+                        **self._fit_params
+                    )
+                grads = self.approximator.model.grads
+                self.replay_memory_priorities = grads / grads.sum()
 
             for er in self._replay_memory:
                 er.update(td_error, self._idxs)
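For reference, a minimal sketch of how the sampling priorities used above could be formed, assuming self.approximator.model.grads exposes one accumulated gradient magnitude per game after the per-game fits; the numbers below are invented for illustration.

import numpy as np

n_games = 4
batch_size = 32

# Invented per-game gradient magnitudes; in the commit they are read from
# self.approximator.model.grads after fitting each game separately.
grads = np.array([0.8, 0.1, 0.4, 0.7])

# Normalize into a probability distribution over games; the core samples
# which MDP to step using these probabilities.
replay_memory_priorities = grads / grads.sum()

# _fit_prioritized now fills a fixed slot of batch_size transitions per game.
for i in range(n_games):
    start = batch_size * i
    stop = start + batch_size
    print(i, start, stop, round(replay_memory_priorities[i], 3))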
9 changes: 7 additions & 2 deletions dqn/run_gym.py
@@ -157,7 +157,7 @@ def experiment(args, idx):
 
     approximator = CustomPyTorchApproximator
 
-    if args.prioritized:
+    if args.er_prior:
        replay_memory = [PrioritizedReplayMemory(
            initial_replay_size, max_replay_size, alpha=.6,
            beta=LinearParameter(.4, threshold_value=1,
@@ -178,6 +178,7 @@ def experiment(args, idx):
         n_actions_per_head=n_actions_per_head,
         clip_reward=False,
         history_length=args.history_length,
+        sampling_prior=args.sampling_prior,
         dtype=np.float32
     )
 
@@ -230,8 +231,10 @@ def experiment(args, idx):
         print('- Learning:')
         # learning step
         pi.set_parameter(None)
+        core.prioritized = args.sampling_prior
         core.learn(n_steps=evaluation_frequency,
                    n_steps_per_fit=train_frequency, quiet=args.quiet)
+        core.prioritized = False
 
         print('- Evaluation:')
         # evaluation step
@@ -284,8 +287,10 @@ def experiment(args, idx):
                          help='Initial size of the replay memory.')
     arg_mem.add_argument("--max-replay-size", type=int, default=5000,
                          help='Max size of the replay memory.')
-    arg_mem.add_argument("--prioritized", action='store_true',
+    arg_mem.add_argument("--er-prior", action='store_true',
                          help='Whether to use prioritized memory or not.')
+    arg_mem.add_argument("--sampling-prior", action='store_true',
+                         help='Whether to use prioritized sampling or not.')
 
     arg_net = parser.add_argument_group('Deep Q-Network')
     arg_net.add_argument("--optimizer",
