Skip to content

Commit

Permalink
Rewrite episode recorder for efficiency.
Browse files Browse the repository at this point in the history
  • Loading branch information
danijar committed Sep 5, 2021
1 parent f482532 commit 3a95961
Show file tree
Hide file tree
Showing 2 changed files with 148 additions and 63 deletions.
209 changes: 147 additions & 62 deletions crafter/recorder.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,89 +9,174 @@
class Recorder:

def __init__(
self, env, directory=None, save_stats=True, save_video=True,
self, env, directory, save_stats=True, save_video=True,
save_episode=True, video_size=(512, 512)):
if directory and save_stats:
env = StatsRecorder(env, directory)
if directory and save_video:
env = VideoRecorder(env, directory, video_size)
if directory and save_episode:
env = EpisodeRecorder(env, directory)
self._env = env
self._directory = directory and pathlib.Path(directory)
self._save_stats = save_stats
self._save_episode = save_episode
self._save_video = save_video
self._video_size = video_size
self._frames = []

def __getattr__(self, name):
if name.startswith('__'):
raise AttributeError(name)
return getattr(self._env, name)


class StatsRecorder:

def __init__(self, env, directory):
self._env = env
self._directory = pathlib.Path(directory)
self._directory.mkdir(exist_ok=True, parents=True)
self._file = (self._directory / 'stats.jsonl').open('a')
self._length = None
self._unlocked = None
self._stats = None

def __getattr__(self, name):
if name.startswith('__'):
raise AttributeError(name)
return getattr(self._env, name)

def reset(self):
obs = self._env.reset()
self._length = 0
self._unlocked = None
self._stats = None
return obs

def step(self, action):
obs, reward, done, info = self._env.step(action)
self._length += 1
if done:
self._stats = {'length': self._length}
for key, value in info['achievements'].items():
self._stats[f'achievement_{key}'] = value
self._save()
return obs, reward, done, info

def _save(self):
self._file.write(json.dumps(self._stats) + '\n')
self._file.flush()


class VideoRecorder:

def __init__(self, env, directory, size=(512, 512)):
if not hasattr(env, 'episode_name'):
env = EpisodeName(env)
self._env = env
self._directory = pathlib.Path(directory)
self._directory.mkdir(exist_ok=True, parents=True)
self._size = size
self._frames = None

def __getattr__(self, name):
if name.startswith('__'):
raise AttributeError(name)
return getattr(self._env, name)

def reset(self):
obs = self._env.reset()
self._frames = [self._env.render(self._size)]
return obs

def step(self, action):
obs, reward, done, info = self._env.step(action)
self._frames.append(self._env.render(self._size))
if done:
self._save()
return obs, reward, done, info

def _save(self):
filename = str(self._directory / (self._env.episode_name + '.mp4'))
imageio.mimsave(self._directory / filename, self._frames)


class EpisodeRecorder:

def __init__(self, env, directory):
if not hasattr(env, 'episode_name'):
env = EpisodeName(env)
self._env = env
self._directory = pathlib.Path(directory)
self._directory.mkdir(exist_ok=True, parents=True)
self._episode = None
if self._directory:
self._directory.mkdir(exist_ok=True, parents=True)
if self._save_stats:
# We keep the file handle open because some cloud storage systems
# cannot handle quickly re-opening the file many times.
self._file_handle = (self._directory / 'stats.jsonl').open('a')

def __getattr__(self, name):
if name.startswith('__'):
raise AttributeError(name)
try:
return getattr(self._env, name)
except AttributeError:
raise ValueError(name)
return getattr(self._env, name)

def reset(self):
# The first time step only contains the initial image that the environment
# returns on reset.
obs = self._env.reset()
self._frames = [self._env.render(self._video_size)]
self._episode = [{'image': obs}]
return obs

def step(self, action):
# Each time step contains the action and the quantities provided by the
# environment in response to the action.
# Transitions are defined from the environment perspective, meaning that a
# transition contains the action and the resulting reward and next
# observation produced by the environment in response to said action.
obs, reward, done, info = self._env.step(action)
self._frames.append(self._env.render(self._video_size))
details = info.copy()
details.update({
f'inventory_{k}': v for k, v
in details.pop('inventory').items()})
details.update({
f'achievement_{k}': v for k, v
in details.pop('achievements').items()})
self._episode.append({
'action': action,
'image': obs,
'reward': reward,
'done': done,
**details,
})
if done and self._directory:
transition = {
'action': action, 'image': obs, 'reward': reward, 'done': done,
}
for key, value in info.items():
if key in ('inventory', 'achievements'):
continue
transition[key] = value
for key, value in info['achievements'].items():
transition[f'achievement_{key}'] = value
for key, value in info['inventory'].items():
transition[f'ainventory_{key}'] = value
self._episode.append(transition)
if done:
self._save()
return obs, reward, done, info

def episode(self):
# Fill in keys for the first time step of the episode.
def _save(self):
filename = str(self._directory / (self._env.episode_name + '.npz'))
# Fill in zeros for keys missing at the first time step.
for key, value in self._episode[1].items():
if key not in self._episode[0]:
self._episode[0][key] = np.zeros_like(value)
return {
episode = {
k: np.array([step[k] for step in self._episode])
for k in self._episode[0]}
np.savez_compressed(filename, **episode)

def _save(self):
eps = self.episode()
score = round(eps['reward'].sum(), 1)
length = len(eps['reward'])
ach = {
k: v[-1].item() for k, v in eps.items()
if k.startswith('achievement_')}
unlocked = sum(int(v > 0) for v in ach.values())
time = datetime.datetime.now().strftime('%Y%m%dT%H%M%S')
path = str(self._directory / f'{time}-ach{unlocked:02}-len{length:05}')
if self._save_stats:
content = {'score': score, 'length': length}
content.update(**{
k: v[-1].item() for k, v in eps.items()
if k.startswith('achievement_')})
self._file_handle.write(json.dumps(content) + '\n')
self._file_handle.flush()
if self._save_video:
imageio.mimsave(path + '.mp4', self._frames)
if self._save_episode:
np.savez_compressed(path + '.npz', **eps)

class EpisodeName:

def __init__(self, env):
self._env = env
self._timestamp = None
self._unlocked = None
self._length = None

def __getattr__(self, name):
if name.startswith('__'):
raise AttributeError(name)
return getattr(self._env, name)

def reset(self):
obs = self._env.reset()
self._timestamp = None
self._unlocked = None
self._length = 0
return obs

def step(self, action):
obs, reward, done, info = self._env.step(action)
self._length += 1
if done:
self._timestamp = datetime.datetime.now().strftime('%Y%m%dT%H%M%S')
self._unlocked = sum(int(v >= 1) for v in info['achievements'].values())
return obs, reward, done, info

@property
def episode_name(self):
return f'{self._timestamp}-ach{self._unlocked}-len{self._length}'
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

setuptools.setup(
name='crafter',
version='1.4.1',
version='1.5.0',
description='Open world survival game for reinforcement learning.',
url='http://github.com/danijar/crafter',
long_description=pathlib.Path('README.md').read_text(),
Expand Down

0 comments on commit 3a95961

Please sign in to comment.