play_a3c_model.py
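"""Run a trained A3C actor on BreakoutDeterministic-v4.

Loads the actor weights from save_model/breakout_a3c_actor.h5 and plays
greedily (argmax over the softmax policy) while rendering each frame.
"""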
import gym
import random
import numpy as np
from skimage.color import rgb2gray
from skimage.transform import resize
from keras.models import Model
from keras.layers import Dense, Flatten, Input, Conv2D

episode = 0
EPISODES = 8000000  # upper bound on the number of evaluation episodes
env_name = "BreakoutDeterministic-v4"

class TestAgent:
    def __init__(self, action_size):
        self.state_size = (84, 84, 4)  # four stacked 84x84 grayscale frames
        self.action_size = action_size
        self.actor = self.build_model()

    def build_model(self):
        # Actor network: two conv layers, one dense layer, softmax policy head.
        input = Input(shape=self.state_size)
        conv = Conv2D(16, (8, 8), strides=(4, 4), activation='relu')(input)
        conv = Conv2D(32, (4, 4), strides=(2, 2), activation='relu')(conv)
        conv = Flatten()(conv)
        fc = Dense(256, activation='relu')(conv)
        policy = Dense(self.action_size, activation='softmax')(fc)

        actor = Model(inputs=input, outputs=policy)
        actor.summary()
        return actor

    def get_action(self, history):
        # Scale pixel values to [0, 1] and act greedily on the policy.
        history = np.float32(history / 255.)
        policy = self.actor.predict(history)[0]
        action_index = np.argmax(policy)
        return action_index

    def load_model(self, name):
        self.actor.load_weights(name)

def pre_processing(next_observe, observe):
    # Max over two consecutive frames removes Atari sprite flicker,
    # then convert to grayscale and resize to 84x84.
    processed_observe = np.maximum(next_observe, observe)
    processed_observe = np.uint8(
        resize(rgb2gray(processed_observe), (84, 84), mode='constant') * 255)
    return processed_observe

if __name__ == '__main__':
    env = gym.make(env_name)
    agent = TestAgent(action_size=3)
    agent.load_model('save_model/breakout_a3c_actor.h5')

    step = 0
    while episode < EPISODES:
        done = False
        dead = False

        score, start_life = 0, 5
        observe = env.reset()
        next_observe = observe

        # Do nothing for a random number of frames so each episode
        # starts from a slightly different state.
        for _ in range(random.randint(1, 20)):
            observe = next_observe
            next_observe, _, _, _ = env.step(1)

        state = pre_processing(next_observe, observe)
        history = np.stack((state, state, state, state), axis=2)
        history = np.reshape([history], (1, 84, 84, 4))

        while not done:
            env.render()
            step += 1
            observe = next_observe
            action = agent.get_action(history)

            # Map the agent's 3 outputs onto Breakout's real actions:
            # 1 = FIRE, 2 = RIGHT, 3 = LEFT.
            if action == 1:
                fake_action = 2
            elif action == 2:
                fake_action = 3
            else:
                fake_action = 1

            # After losing a life, FIRE to launch a new ball.
            if dead:
                fake_action = 1
                dead = False

            next_observe, reward, done, info = env.step(fake_action)

            # Update the frame stack: prepend the newest frame, drop the oldest.
            next_state = pre_processing(next_observe, observe)
            next_state = np.reshape([next_state], (1, 84, 84, 1))
            next_history = np.append(next_state, history[:, :, :, :3], axis=3)

            if start_life > info['ale.lives']:
                dead = True
                reward = -1
                start_life = info['ale.lives']

            score += reward

            if dead:
                # Restart the frame stack from the current frame after a death.
                history = np.stack((next_state, next_state, next_state, next_state), axis=2)
                history = np.reshape([history], (1, 84, 84, 4))
            else:
                history = next_history

            if done:
                episode += 1
                print('episode:', episode, ' score:', score, ' step:', step)
                step = 0