diff --git a/core.py b/core.py
index 7a05e8c..85aaa99 100644
--- a/core.py
+++ b/core.py
@@ -17,8 +17,6 @@ def __init__(self, agent, mdp, callbacks=None):
         self._episode_steps = [None for _ in range(self._n_mdp)]
         self._n_steps_per_fit = None
 
-        self.prioritized = False
-
     def learn(self, n_steps=None, n_steps_per_fit=None, render=False,
               quiet=False):
         self._n_steps_per_fit = n_steps_per_fit
@@ -52,12 +50,7 @@ def _run_impl(self, move_condition, fit_condition, steps_progress_bar,
         dataset = list()
         last = [True] * self._n_mdp
         while move_condition():
-            if self.prioritized:
-                p = self.agent.replay_memory_priorities
-                mdps = [np.random.choice(self._n_mdp, p=p)]
-            else:
-                mdps = np.arange(self._n_mdp)
-            for i in mdps:
+            for i in range(self._n_mdp):
                 if last[i]:
                     self.reset(i)
 
diff --git a/dqn/dqn.py b/dqn/dqn.py
index cac6a20..198880f 100644
--- a/dqn/dqn.py
+++ b/dqn/dqn.py
@@ -20,7 +20,7 @@ def __init__(self, approximator, policy, mdp_info, batch_size,
                  history_length=4, n_input_per_mdp=None, replay_memory=None,
                  target_update_frequency=2500, fit_params=None,
                  approximator_params=None, n_games=1, clip_reward=True,
-                 sampling_prior=False, dtype=np.uint8):
+                 dtype=np.uint8):
         self._fit_params = dict() if fit_params is None else fit_params
 
         self._batch_size = batch_size
@@ -80,9 +80,6 @@ def __init__(self, approximator, policy, mdp_info, batch_size,
         self._absorbing = np.zeros(n_samples)
         self._idxs = np.zeros(n_samples, dtype=np.int)
         self._is_weight = np.zeros(n_samples)
-        self.replay_memory_priorities = np.ones(self._n_games) / self._n_games
-
-        self._sampling_prior = sampling_prior
 
     def fit(self, dataset):
         self._fit(dataset)
@@ -129,19 +126,6 @@ def _fit_standard(self, dataset):
             q_next = self._next_q()
             q = reward + q_next
 
-            if self._sampling_prior:
-                grads = np.zeros(self._n_games)
-                for i in range(self._n_games):
-                    idxs = np.argwhere(self._state_idxs == i).ravel()
-                    self.approximator.fit(
-                        self._state[idxs], self._action[idxs], q[idxs],
-                        idx=self._state_idxs[idxs],
-                        params=self.approximator.model.network.get_shared_weights_tensor(),
-                        **self._fit_params
-                    )
-                    grads[i] = self.approximator.model.grad
-                self.replay_memory_priorities = grads / grads.sum()
-
             self.approximator.fit(self._state, self._action, q,
                                   idx=self._state_idxs, **self._fit_params)
 
@@ -190,20 +174,6 @@ def _fit_prioritized(self, dataset):
                                            idx=self._state_idxs)
             td_error = q - q_current
 
-            if self._sampling_prior:
-                grads = np.zeros(self._n_games)
-                for i in range(self._n_games):
-                    idxs = np.argwhere(self._state_idxs == i).ravel()
-                    self.approximator.fit(
-                        self._state[idxs], self._action[idxs], q[idxs],
-                        weights=self._is_weight[idxs],
-                        idx=self._state_idxs[idxs],
-                        params=self.approximator.model.network.get_shared_weights_tensor(),
-                        **self._fit_params
-                    )
-                    grads[i] = self.approximator.model.grad
-                self.replay_memory_priorities = grads / grads.sum()
-
             for er in self._replay_memory:
                 er.update(td_error, self._idxs)
 
diff --git a/dqn/run_gym.py b/dqn/run_gym.py
index 14ec2ca..907c33d 100644
--- a/dqn/run_gym.py
+++ b/dqn/run_gym.py
@@ -157,7 +157,7 @@ def experiment(args, idx):
 
         approximator = CustomPyTorchApproximator
 
-        if args.er_prior:
+        if args.prioritized:
             replay_memory = [PrioritizedReplayMemory(
                 initial_replay_size, max_replay_size, alpha=.6,
                 beta=LinearParameter(.4, threshold_value=1,
@@ -178,7 +178,6 @@ def experiment(args, idx):
             n_actions_per_head=n_actions_per_head,
             clip_reward=False,
             history_length=args.history_length,
-            sampling_prior=args.sampling_prior,
             dtype=np.float32
         )
 
@@ -231,10 +230,8 @@ def experiment(args, idx):
             print('- Learning:')
             # learning step
             pi.set_parameter(None)
-            core.prioritized = args.sampling_prior
             core.learn(n_steps=evaluation_frequency,
                        n_steps_per_fit=train_frequency, quiet=args.quiet)
-            core.prioritized = False
 
             print('- Evaluation:')
             # evaluation step
@@ -287,10 +284,8 @@ def experiment(args, idx):
                         help='Initial size of the replay memory.')
     arg_mem.add_argument("--max-replay-size", type=int, default=5000,
                         help='Max size of the replay memory.')
-    arg_mem.add_argument("--er-prior", action='store_true',
+    arg_mem.add_argument("--prioritized", action='store_true',
                         help='Whether to use prioritized memory or not.')
-    arg_mem.add_argument("--sampling-prior", action='store_true',
-                        help='Whether to use prioritized sampling or not.')
 
     arg_net = parser.add_argument_group('Deep Q-Network')
     arg_net.add_argument("--optimizer",
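
Note for reviewers: after this patch, prioritized experience replay is controlled solely by the renamed `--prioritized` flag, which builds one `PrioritizedReplayMemory` per game as shown in the `run_gym.py` hunk above. A minimal sketch of that construction follows; the import paths, `n_games`, and the `LinearParameter` annealing length are assumptions (the hunk is cut off before the remaining `LinearParameter` arguments), not part of the patch.

# Sketch only: import paths and the numeric placeholders below are assumed,
# mirroring mushroom-rl-style utilities rather than this repo's exact modules.
from mushroom_rl.utils.parameters import LinearParameter
from mushroom_rl.utils.replay_memory import PrioritizedReplayMemory

n_games = 2                  # placeholder: number of tasks/MDPs
initial_replay_size = 100    # matches the script defaults shown above
max_replay_size = 5000
beta_steps = 50000           # placeholder annealing length for beta

# One prioritized memory per game, mirroring the `if args.prioritized:` branch:
# priorities follow TD error (alpha=.6), and importance-sampling correction
# beta anneals linearly from .4 to 1.
replay_memory = [
    PrioritizedReplayMemory(
        initial_replay_size, max_replay_size, alpha=.6,
        beta=LinearParameter(.4, threshold_value=1, n=beta_steps)
    ) for _ in range(n_games)
]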