
Commit d4bdedf

fixed alpha SAC discrete error
p-christ committed Jul 18, 2019
1 parent 77c47fa commit d4bdedf
Showing 6 changed files with 93 additions and 7 deletions.
.gitignore (4 additions & 1 deletion)
@@ -19,4 +19,7 @@ Random_Junkyard/
*to_do_list
Notebook.ipynb
Results/Notebook.ipynb
-*.ipynb_checkpoints
+*.ipynb_checkpoints
+*.drive_access_key.json
+drive_access_key.json
+drive_access_key
agents/Base_Agent.py (9 additions & 2 deletions)
@@ -134,15 +134,22 @@ def log_game_info(self):

    def set_random_seeds(self, random_seed):
        """Sets all possible random seeds so results can be reproduced"""
        os.environ['PYTHONHASHSEED'] = str(random_seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        torch.manual_seed(random_seed)
        tf.set_random_seed(random_seed)
        random.seed(random_seed)
        np.random.seed(random_seed)
        if torch.cuda.is_available(): torch.cuda.manual_seed_all(random_seed)
        self.config.seed = random_seed
        if torch.cuda.is_available():
            torch.cuda.manual_seed_all(random_seed)
            torch.cuda.manual_seed(random_seed)
        if hasattr(gym.spaces, 'prng'):
            gym.spaces.prng.seed(random_seed)

    def reset_game(self):
        """Resets the game information so we are ready to play a new episode"""
        self.environment.seed(self.config.seed)
        self.state = self.environment.reset()
        self.next_state = None
        self.action = None
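
The seeding above is what makes runs repeatable across Python's random module, NumPy, PyTorch (CPU and CUDA) and gym. As a quick standalone sanity check, not repository code, the core pattern can be exercised like this (the environment and gym.spaces parts are omitted):

import random

import numpy as np
import torch

def seeded_draws(seed):
    # Mirrors the core of set_random_seeds: seed every generator before drawing from it
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    return random.random(), float(np.random.rand()), float(torch.rand(1))

assert seeded_draws(1) == seeded_draws(1)   # same seed, identical draws from all three sources
assert seeded_draws(1) != seeded_draws(2)   # different seed, different draws
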
agents/actor_critic_agents/SAC_Discrete.py (2 additions & 2 deletions)
@@ -36,7 +36,7 @@ def __init__(self, config):
                                      lr=self.hyperparameters["Actor"]["learning_rate"])
        self.automatic_entropy_tuning = self.hyperparameters["automatically_tune_entropy_hyperparameter"]
        if self.automatic_entropy_tuning:
-            self.target_entropy = -torch.prod(torch.Tensor(self.environment.action_space.shape).to(self.device)).item() # heuristic value from the paper
+            self.target_entropy = - self.environment.unwrapped.action_space.n / 4.0 # heuristic value from the paper
            self.log_alpha = torch.zeros(1, requires_grad=True, device=self.device)
            self.alpha = self.log_alpha.exp()
            self.alpha_optim = Adam([self.log_alpha], lr=self.hyperparameters["Actor"]["learning_rate"])
@@ -80,7 +80,7 @@ def calculate_actor_loss(self, state_batch):
        qf1_pi = self.critic_local(state_batch)
        qf2_pi = self.critic_local_2(state_batch)
        min_qf_pi = torch.min(qf1_pi, qf2_pi)
-        inside_term = log_action_probabilities - min_qf_pi
+        inside_term = self.alpha * log_action_probabilities - min_qf_pi
        policy_loss = torch.sum(action_probabilities * inside_term)
        policy_loss = policy_loss.mean()
        log_action_probabilities = log_action_probabilities.gather(1, action.unsqueeze(-1).long())
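
The substantive fix is the self.alpha factor: without it the actor loss ignores the learned temperature, so automatically tuning the entropy coefficient has no effect on the policy update. A minimal sketch (not the repository's exact code; tensor names follow the diff, and the temperature update shown is one common formulation) of how the corrected actor loss and the alpha loss fit together for discrete actions:

import torch

def discrete_sac_actor_and_alpha_losses(action_probabilities, log_action_probabilities,
                                        qf1_pi, qf2_pi, log_alpha, target_entropy):
    # action_probabilities, log_action_probabilities, qf1_pi, qf2_pi: (batch_size, n_actions) tensors
    # log_alpha: learnable scalar parameter; target_entropy: float (the diff above uses -action_space.n / 4.0)
    alpha = log_alpha.exp()
    min_qf_pi = torch.min(qf1_pi, qf2_pi)
    inside_term = alpha.detach() * log_action_probabilities - min_qf_pi   # entropy term scaled by alpha
    policy_loss = (action_probabilities * inside_term).sum(dim=1).mean()
    # Temperature loss: gradient descent on log_alpha raises alpha when the policy entropy
    # is below target_entropy and lowers it when the entropy is above it
    entropy_gap = (action_probabilities * log_action_probabilities).sum(dim=1) + target_entropy
    alpha_loss = -(log_alpha * entropy_gap.detach()).mean()
    return policy_loss, alpha_loss
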
exploration_strategies/Epsilon_Greedy_Exploration.py (1 addition & 1 deletion)
@@ -35,7 +35,7 @@ def perturb_action_for_exploration_purposes(self, action_info):

        if (random.random() > epsilon or turn_off_exploration) and (episode_number >= self.random_episodes_to_run):
            return torch.argmax(action_values).item()
-        return random.randint(0, action_values.shape[1] - 1)
+        return np.random.randint(0, action_values.shape[1])

    def get_updated_epsilon_exploration(self, action_info, epsilon=1.0):
        """Gets the probability that we just pick a random action. This probability decays the more episodes we have seen"""
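
Note the bound change that goes with the switch: random.randint includes both endpoints, while np.random.randint excludes the upper bound, so the "- 1" is no longer needed to keep the sampled action index in range. A small illustration (the 4-action action_values tensor is hypothetical):

import random

import numpy as np
import torch

action_values = torch.zeros(1, 4)              # hypothetical batch of Q-values for 4 discrete actions
n = action_values.shape[1]

assert 0 <= random.randint(0, n - 1) <= n - 1  # inclusive upper bound, hence the old "- 1"
assert 0 <= np.random.randint(0, n) <= n - 1   # exclusive upper bound, so passing n is already safe
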
utilities/Deepmind_RMS_Prop.py (new file: 76 additions & 0 deletions)
@@ -0,0 +1,76 @@
import torch
from torch.optim import Optimizer


class DM_RMSprop(Optimizer):
    """Implements the form of RMSProp used in the DeepMind 2015 Atari paper.
    Inspired by https://github.com/spragunr/deep_q_rl/blob/master/deep_q_rl/updates.py"""

    def __init__(self, params, lr=1e-2, alpha=0.99, eps=1e-8, weight_decay=0, momentum=0, centered=False):
        if not 0.0 <= lr:
            raise ValueError("Invalid learning rate: {}".format(lr))
        if not 0.0 <= eps:
            raise ValueError("Invalid epsilon value: {}".format(eps))
        if not 0.0 <= momentum:
            raise ValueError("Invalid momentum value: {}".format(momentum))
        if not 0.0 <= weight_decay:
            raise ValueError("Invalid weight_decay value: {}".format(weight_decay))
        if not 0.0 <= alpha:
            raise ValueError("Invalid alpha value: {}".format(alpha))

        defaults = dict(lr=lr, momentum=momentum, alpha=alpha, eps=eps, centered=centered, weight_decay=weight_decay)
        super(DM_RMSprop, self).__init__(params, defaults)

    def __setstate__(self, state):
        super(DM_RMSprop, self).__setstate__(state)
        for group in self.param_groups:
            group.setdefault('momentum', 0)
            group.setdefault('centered', False)

    def step(self, closure=None):
        """Performs a single optimization step.
        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.
        """
        loss = None
        if closure is not None:
            loss = closure()
        for group in self.param_groups:
            momentum = group['momentum']       # decay rate of the running gradient average
            sq_momentum = group['alpha']       # decay rate of the running squared-gradient average
            epsilon = group['eps']

            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError('RMSprop does not support sparse gradients')
                state = self.state[p]

                # State initialization
                if len(state) == 0:
                    state['step'] = 0
                    state['square_avg'] = torch.zeros_like(p.data)
                    # The gradient-average buffer is used in the update below even when
                    # momentum == 0, so it is always initialised
                    state['momentum_buffer'] = torch.zeros_like(p.data)

                mom_buffer = state['momentum_buffer']
                square_avg = state['square_avg']

                state['step'] += 1

                # g_t = momentum * g_{t-1} + (1 - momentum) * grad
                mom_buffer.mul_(momentum)
                mom_buffer.add_((1 - momentum) * grad)

                # n_t = alpha * n_{t-1} + (1 - alpha) * grad ** 2
                square_avg.mul_(sq_momentum).addcmul_(1 - sq_momentum, grad, grad)

                # Normalise by sqrt(n_t - g_t ** 2 + eps), an estimate of the gradient's standard deviation
                avg = (square_avg - mom_buffer ** 2 + epsilon).sqrt()

                p.data.addcdiv_(-group['lr'], grad, avg)

        return loss
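
For context, a hedged usage sketch of the new optimiser on a throwaway network; the model, data and hyperparameter values below are illustrative placeholders (roughly the settings usually quoted for the 2015 DQN setup), not the repository's configuration:

import torch
import torch.nn as nn

from utilities.Deepmind_RMS_Prop import DM_RMSprop

model = nn.Sequential(nn.Linear(4, 32), nn.ReLU(), nn.Linear(32, 2))   # hypothetical network
optimizer = DM_RMSprop(model.parameters(), lr=0.00025, alpha=0.95, eps=0.01, momentum=0.95)

inputs, targets = torch.randn(8, 4), torch.randn(8, 2)                 # dummy regression batch
for _ in range(10):
    optimizer.zero_grad()
    loss = nn.functional.mse_loss(model(inputs), targets)
    loss.backward()
    optimizer.step()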

utilities/data_structures/Replay_Buffer.py (1 addition & 1 deletion)
@@ -46,7 +46,7 @@ def separate_out_data_types(self, experiences):
        return states, actions, rewards, next_states, dones

    def pick_experiences(self, num_experiences=None):
-        if num_experiences: batch_size = num_experiences
+        if num_experiences is not None: batch_size = num_experiences
        else: batch_size = self.batch_size
        return random.sample(self.memory, k=batch_size)

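The is-not-None check matters because an explicit request for zero experiences is falsy and would silently fall back to the default batch size. A minimal illustration (memory, default_batch_size and the two helper functions are hypothetical stand-ins, not repository code):

import random

memory = list(range(100))          # stand-in for the buffer contents
default_batch_size = 5

def pick_old(num_experiences=None):
    batch_size = num_experiences if num_experiences else default_batch_size
    return random.sample(memory, k=batch_size)

def pick_new(num_experiences=None):
    batch_size = num_experiences if num_experiences is not None else default_batch_size
    return random.sample(memory, k=batch_size)

assert len(pick_old(0)) == 5       # 0 is falsy, so the old check silently used the default
assert len(pick_new(0)) == 0       # the new check honours an explicit request for zero samples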
