-
Notifications
You must be signed in to change notification settings - Fork 1
/
gire_episode_runner.py
134 lines (101 loc) · 4.79 KB
/
gire_episode_runner.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
from envs import REGISTRY as env_REGISTRY
from functools import partial
from components.episode_buffer import EpisodeBatch
import numpy as np
class GireEpisodeRunner:
    """Run one environment episode at a time and record it in an ``EpisodeBatch``.

    Only ``batch_size_run == 1`` is supported (checked in ``__init__``).
    In training mode, actions are selected with a latent ``Z`` sampled from
    ``mac.coach_net``. In test mode, ``select_actions`` is called without ``Z``.
    """

    def __init__(self, args, logger):
        """Create the environment and initialise counters and stat buffers.

        Args:
            args: Config object. This class reads ``batch_size_run``, ``env``,
                ``env_args``, ``device``, ``test_nepisode`` and
                ``runner_log_interval`` from it.
            logger: Object exposing ``log_stat(key, value, t)``.
        """
        self.args = args
        self.logger = logger
        self.batch_size = self.args.batch_size_run
        assert self.batch_size == 1
        self.env = env_REGISTRY[self.args.env](**self.args.env_args)
        self.episode_limit = self.env.episode_limit
        self.t = 0       # timestep within the current episode
        self.t_env = 0   # environment steps accumulated over training episodes
        self.train_returns = []
        self.test_returns = []
        self.train_stats = {}
        self.test_stats = {}
        # Start far in the past so the first training episode is logged.
        self.log_train_stats_t = -1000000

    def setup(self, scheme, groups, preprocess, mac):
        """Store a factory for fresh episode batches and the multi-agent controller.

        The factory binds ``EpisodeBatch`` to ``episode_limit + 1``
        (presumably the max sequence length, leaving room for the final
        state recorded after the last step; confirm against EpisodeBatch).
        """
        self.new_batch = partial(EpisodeBatch, scheme, groups, self.batch_size, self.episode_limit + 1,
                                 preprocess=preprocess, device=self.args.device)
        self.mac = mac

    def get_env_info(self):
        """Return the wrapped environment's ``get_env_info()``."""
        return self.env.get_env_info()

    def save_replay(self):
        """Delegate to the environment's ``save_replay()``."""
        self.env.save_replay()

    def close_env(self):
        """Close the wrapped environment."""
        self.env.close()

    def reset(self):
        """Start a new episode: fresh batch, reset env, timestep back to 0."""
        self.batch = self.new_batch()
        self.env.reset()
        self.t = 0

    def _env_snapshot(self):
        """Return the current state, available actions and observations for a batch of one."""
        return {
            "state": [self.env.get_state()],
            "avail_actions": [self.env.get_avail_actions()],
            "obs": [self.env.get_obs()],
        }

    def _select_step_actions(self, test_mode):
        """Select actions for timestep ``self.t`` inside the episode loop.

        Training mode: sample ``Z`` from ``mac.coach_net`` and pass it to
        ``select_actions``. Test mode: call ``select_actions`` without ``Z``.
        """
        if test_mode:
            actions, _ = self.mac.select_actions(self.batch, t_ep=self.t, t_env=self.t_env,
                                                 test_mode=test_mode)
            return actions
        obs_input = self.mac._build_inputs(self.batch, t=self.t)
        z_dist = self.mac.coach_net.forward(self.batch, obs_input, t=self.t, return_one=True)
        # NOTE(review): z_dist is presumably a torch distribution. rsample() draws
        # a reparameterised (differentiable) sample and nothing here disables
        # gradient tracking -- confirm that is intended.
        actions, _ = self.mac.select_actions(self.batch, t_ep=self.t, t_env=self.t_env,
                                             test_mode=test_mode, Z=z_dist.rsample())
        return actions

    def run(self, test_mode=False):
        """Play one full episode and return the filled episode batch.

        Args:
            test_mode: If True, actions are selected without a coach latent ``Z``.
                Returns and stats go to the test buffers, and ``t_env`` is not
                advanced.

        Returns:
            The ``EpisodeBatch`` created by ``reset()`` for this episode.
        """
        self.reset()
        terminated = False
        episode_return = 0
        self.mac.init_hidden(batch_size=self.batch_size)

        while not terminated:
            self.batch.update(self._env_snapshot(), ts=self.t)
            actions = self._select_step_actions(test_mode)
            # Store a CPU numpy copy of the actions (original note: "Fix memory leak").
            cpu_actions = actions.to("cpu").numpy()
            reward, terminated, env_info = self.env.step(actions[0])
            episode_return += reward
            post_transition_data = {
                "actions": cpu_actions,
                "reward": [(reward,)],
                # If the env flags the ending as due to "episode_limit", store it
                # as not terminated.
                "terminated": [(terminated != env_info.get("episode_limit", False),)],
            }
            self.batch.update(post_transition_data, ts=self.t)
            self.t += 1

        # Record the final state and select actions in it so the last timestep is filled.
        self.batch.update(self._env_snapshot(), ts=self.t)
        # select_actions returns a pair (as unpacked inside the loop); take the actions.
        # NOTE(review): no Z is passed here, even in training mode -- confirm.
        actions, _ = self.mac.select_actions(self.batch, t_ep=self.t, t_env=self.t_env,
                                             test_mode=test_mode)
        cpu_actions = actions.to("cpu").numpy()
        self.batch.update({"actions": cpu_actions}, ts=self.t)

        self._record_episode(episode_return, env_info, test_mode)
        return self.batch

    def _record_episode(self, episode_return, env_info, test_mode):
        """Accumulate this episode's return/stats and log them when due.

        Test stats are logged once ``test_nepisode`` test episodes have been
        collected. Training stats are logged every ``runner_log_interval``
        environment steps.
        """
        cur_stats = self.test_stats if test_mode else self.train_stats
        cur_returns = self.test_returns if test_mode else self.train_returns
        log_prefix = "test_" if test_mode else ""
        # Sum every numeric key of the final env_info into the running stats.
        cur_stats.update({k: cur_stats.get(k, 0) + env_info.get(k, 0) for k in set(cur_stats) | set(env_info)})
        cur_stats["n_episodes"] = 1 + cur_stats.get("n_episodes", 0)
        cur_stats["ep_length"] = self.t + cur_stats.get("ep_length", 0)

        if not test_mode:
            self.t_env += self.t

        cur_returns.append(episode_return)

        if test_mode:
            if len(self.test_returns) == self.args.test_nepisode:
                self._log(cur_returns, cur_stats, log_prefix)
        elif self.t_env - self.log_train_stats_t >= self.args.runner_log_interval:
            # Only training episodes reach this branch. Otherwise a partial set of
            # test stats could be logged and cleared mid-evaluation.
            self._log(cur_returns, cur_stats, log_prefix)
            if hasattr(self.mac.action_selector, "epsilon"):
                self.logger.log_stat("epsilon", self.mac.action_selector.epsilon, self.t_env)
            self.log_train_stats_t = self.t_env

    def _log(self, returns, stats, prefix):
        """Log return mean/std and per-episode means of stats, then clear both buffers.

        Args:
            returns: List of episode returns; cleared in place.
            stats: Dict of summed stats including ``n_episodes``; cleared in place.
            prefix: String prepended to every logged key (``"test_"`` or ``""``).
        """
        self.logger.log_stat(prefix + "return_mean", np.mean(returns), self.t_env)
        self.logger.log_stat(prefix + "return_std", np.std(returns), self.t_env)
        returns.clear()
        for k, v in stats.items():
            if k != "n_episodes":
                self.logger.log_stat(prefix + k + "_mean" , v/stats["n_episodes"], self.t_env)
        stats.clear()