Add VizDoom PPO example and results (#533)
* update vizdoom ppo example

* update README with results
nuance1979 committed Feb 25, 2022
1 parent 23fbc3b commit 97df511
Showing 8 changed files with 102 additions and 32 deletions.
29 changes: 25 additions & 4 deletions examples/vizdoom/README.md
@@ -53,14 +53,35 @@ python3 replay.py maps/D4_battle2.cfg results/c51/d4.lmp

See [maps/README.md](maps/README.md)

## Reward

1. A living reward is harmful.
2. Combo actions are really important.
3. A negative reward for health and ammo2 is really helpful for D3/D4 (see the sketch after this list).
4. Using only a positive reward for health is really helpful for D1.
5. Removing MOVE_BACKWARD may lead to faster convergence, but the final performance may be lower.
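
As a rough illustration of points 1 and 3, the sketch below shows one way such shaping could be written as a Gym-style wrapper. It is not the code used in this example: the wrapper name, the `health`/`ammo2` info keys, and the penalty coefficients are assumptions made for illustration, and whether the example's own `Env` exposes these game variables in `info` is not shown in this diff.

```python
import gym


class ShapedReward(gym.Wrapper):
    """Hypothetical reward-shaping wrapper, not the one shipped with this example.

    Assumes the wrapped env reports the current ``health`` and ``ammo2``
    game variables in ``info``. Instead of a living reward (point 1), it
    penalizes losing health and ammo2 (point 3).
    """

    def __init__(self, env, health_coef=1.0, ammo_coef=1.0):
        super().__init__(env)
        self.health_coef = health_coef
        self.ammo_coef = ammo_coef
        self.prev = {}

    def reset(self, **kwargs):
        self.prev = {}
        return self.env.reset(**kwargs)

    def step(self, action):
        obs, rew, done, info = self.env.step(action)
        for key, coef in (("health", self.health_coef), ("ammo2", self.ammo_coef)):
            cur = info.get(key)
            last = self.prev.get(key)
            if cur is not None and last is not None and cur < last:
                # penalize the amount of health/ammo lost this step
                rew -= coef * (last - cur)
            if cur is not None:
                self.prev[key] = cur
        return obs, rew, done, info
```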

## Algorithms

The setting is exactly the same as in the Atari examples, so you can also try the other algorithms listed there.

### C51 (single run)

| task | best reward | reward curve | parameters |
| --------------------------- | ----------- | ------------------------------------- | ------------------------------------------------------------ |
| D2_navigation | 747.52 | ![](results/c51/D2_navigation_rew.png) | `python3 vizdoom_c51.py --task "D2_navigation"` |
| D3_battle | 1855.29 | ![](results/c51/D3_battle_rew.png) | `python3 vizdoom_c51.py --task "D3_battle"` |

### PPO (single run)

| task | best reward | reward curve | parameters |
| --------------------------- | ----------- | ------------------------------------- | ------------------------------------------------------------ |
| D2_navigation | 770.75 | ![](results/ppo/D2_navigation_rew.png) | `python3 vizdoom_ppo.py --task "D2_navigation"` |
| D3_battle | 320.59 | ![](results/ppo/D3_battle_rew.png) | `python3 vizdoom_ppo.py --task "D3_battle"` |

### PPO with ICM (single run)

| task | best reward | reward curve | parameters |
| --------------------------- | ----------- | ------------------------------------- | ------------------------------------------------------------ |
| D2_navigation | 844.99 | ![](results/ppo_icm/D2_navigation_rew.png) | `python3 vizdoom_ppo.py --task "D2_navigation" --icm-lr-scale 10` |
| D3_battle | 547.08 | ![](results/ppo_icm/D3_battle_rew.png) | `python3 vizdoom_ppo.py --task "D3_battle" --icm-lr-scale 10` |
Binary file added examples/vizdoom/results/c51/D3_battle_rew.png
Binary file added examples/vizdoom/results/ppo/D3_battle_rew.png
@@ -6,11 +6,12 @@
import torch
from env import Env
from network import DQN
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.tensorboard import SummaryWriter

from tianshou.data import Collector, VectorReplayBuffer
from tianshou.env import ShmemVectorEnv
from tianshou.policy import A2CPolicy, ICMPolicy
from tianshou.policy import ICMPolicy, PPOPolicy
from tianshou.trainer import onpolicy_trainer
from tianshou.utils import TensorboardLogger
from tianshou.utils.net.common import ActorCritic
@@ -21,18 +22,28 @@ def get_args():
parser = argparse.ArgumentParser()
parser.add_argument('--task', type=str, default='D2_navigation')
parser.add_argument('--seed', type=int, default=0)
parser.add_argument('--buffer-size', type=int, default=2000000)
parser.add_argument('--lr', type=float, default=0.0001)
parser.add_argument('--buffer-size', type=int, default=100000)
parser.add_argument('--lr', type=float, default=0.00002)
parser.add_argument('--gamma', type=float, default=0.99)
parser.add_argument('--epoch', type=int, default=300)
parser.add_argument('--step-per-epoch', type=int, default=100000)
parser.add_argument('--episode-per-collect', type=int, default=10)
parser.add_argument('--update-per-step', type=float, default=0.1)
parser.add_argument('--update-per-step', type=int, default=1)
parser.add_argument('--batch-size', type=int, default=64)
parser.add_argument('--hidden-sizes', type=int, nargs='*', default=[512])
parser.add_argument('--step-per-collect', type=int, default=1000)
parser.add_argument('--repeat-per-collect', type=int, default=4)
parser.add_argument('--batch-size', type=int, default=256)
parser.add_argument('--hidden-size', type=int, default=512)
parser.add_argument('--training-num', type=int, default=10)
parser.add_argument('--test-num', type=int, default=100)
parser.add_argument('--rew-norm', type=int, default=False)
parser.add_argument('--vf-coef', type=float, default=0.5)
parser.add_argument('--ent-coef', type=float, default=0.01)
parser.add_argument('--gae-lambda', type=float, default=0.95)
parser.add_argument('--lr-decay', type=int, default=True)
parser.add_argument('--max-grad-norm', type=float, default=0.5)
parser.add_argument('--eps-clip', type=float, default=0.2)
parser.add_argument('--dual-clip', type=float, default=None)
parser.add_argument('--value-clip', type=int, default=0)
parser.add_argument('--norm-adv', type=int, default=1)
parser.add_argument('--recompute-adv', type=int, default=0)
parser.add_argument('--logdir', type=str, default='log')
parser.add_argument('--render', type=float, default=0.)
parser.add_argument(
@@ -75,7 +86,7 @@ def get_args():
return parser.parse_args()


def test_a2c(args=get_args()):
def test_ppo(args=get_args()):
args.cfg_path = f"maps/{args.task}.cfg"
args.wad_path = f"maps/{args.task}.wad"
args.res = (args.skip_num, 84, 84)
@@ -105,33 +116,65 @@ def test_a2c(args=get_args()):
test_envs.seed(args.seed)
# define model
net = DQN(
*args.state_shape, args.action_shape, device=args.device, features_only=True
*args.state_shape,
args.action_shape,
device=args.device,
features_only=True,
output_dim=args.hidden_size
)
actor = Actor(
net, args.action_shape, hidden_sizes=args.hidden_sizes, device=args.device
)
critic = Critic(net, hidden_sizes=args.hidden_sizes, device=args.device)
actor = Actor(net, args.action_shape, device=args.device, softmax_output=False)
critic = Critic(net, device=args.device)
optim = torch.optim.Adam(ActorCritic(actor, critic).parameters(), lr=args.lr)

lr_scheduler = None
if args.lr_decay:
# decay learning rate to 0 linearly
max_update_num = np.ceil(
args.step_per_epoch / args.step_per_collect
) * args.epoch

lr_scheduler = LambdaLR(
optim, lr_lambda=lambda epoch: 1 - epoch / max_update_num
)

# define policy
dist = torch.distributions.Categorical
policy = A2CPolicy(actor, critic, optim, dist).to(args.device)
def dist(p):
return torch.distributions.Categorical(logits=p)

policy = PPOPolicy(
actor,
critic,
optim,
dist,
discount_factor=args.gamma,
gae_lambda=args.gae_lambda,
max_grad_norm=args.max_grad_norm,
vf_coef=args.vf_coef,
ent_coef=args.ent_coef,
reward_normalization=args.rew_norm,
action_scaling=False,
lr_scheduler=lr_scheduler,
action_space=env.action_space,
eps_clip=args.eps_clip,
value_clip=args.value_clip,
dual_clip=args.dual_clip,
advantage_normalization=args.norm_adv,
recompute_advantage=args.recompute_adv
).to(args.device)
if args.icm_lr_scale > 0:
feature_net = DQN(
*args.state_shape,
args.action_shape,
device=args.device,
features_only=True
features_only=True,
output_dim=args.hidden_size
)
action_dim = np.prod(args.action_shape)
feature_dim = feature_net.output_dim
icm_net = IntrinsicCuriosityModule(
feature_net.net,
feature_dim,
action_dim,
hidden_sizes=args.hidden_sizes,
device=args.device
feature_net.net, feature_dim, action_dim, device=args.device
)
icm_optim = torch.optim.adam(icm_net.parameters(), lr=args.lr)
icm_optim = torch.optim.Adam(icm_net.parameters(), lr=args.lr)
policy = ICMPolicy(
policy, icm_net, icm_optim, args.icm_lr_scale, args.icm_reward_scale,
args.icm_forward_loss_weight
@@ -153,7 +196,8 @@ def test_a2c(args=get_args()):
train_collector = Collector(policy, train_envs, buffer, exploration_noise=True)
test_collector = Collector(policy, test_envs, exploration_noise=True)
# log
log_path = os.path.join(args.logdir, args.task, 'a2c')
log_name = 'ppo_icm' if args.icm_lr_scale > 0 else 'ppo'
log_path = os.path.join(args.logdir, args.task, log_name)
writer = SummaryWriter(log_path)
writer.add_text("args", str(args))
logger = TensorboardLogger(writer)
@@ -162,10 +206,15 @@ def save_fn(policy):
torch.save(policy.state_dict(), os.path.join(log_path, 'policy.pth'))

def stop_fn(mean_rewards):
return False
if env.spec.reward_threshold:
return mean_rewards >= env.spec.reward_threshold
elif 'Pong' in args.task:
return mean_rewards >= 20
else:
return False

# watch agent's performance
def watch():
# watch agent's performance
print("Setup test envs ...")
policy.eval()
test_envs.seed(args.seed)
@@ -210,7 +259,7 @@ def watch():
args.repeat_per_collect,
args.test_num,
args.batch_size,
episode_per_collect=args.episode_per_collect,
step_per_collect=args.step_per_collect,
stop_fn=stop_fn,
save_fn=save_fn,
logger=logger,
@@ -222,4 +271,4 @@ def watch():


if __name__ == '__main__':
test_a2c(get_args())
test_ppo(get_args())
