Skip to content

Commit

Permalink
TerminateIllegalWrapper fix (Farama-Foundation#1206)
Browse files Browse the repository at this point in the history
  • Loading branch information
dm-ackerman authored Jun 21, 2024
1 parent 9f441fe commit 1eef080
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 9 deletions.
11 changes: 4 additions & 7 deletions pettingzoo/utils/wrappers/terminate_illegal.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
# pyright reportGeneralTypeIssues=false
from __future__ import annotations

from pettingzoo.utils.env import ActionType, AECEnv, AgentID, ObsType
Expand All @@ -20,6 +19,7 @@ def __init__(
self._illegal_value = illegal_reward
self._prev_obs = None
self._prev_info = None
self._terminated = False # terminated by an illegal move

def reset(self, seed: int | None = None, options: dict | None = None) -> None:
self._terminated = False
Expand All @@ -42,7 +42,6 @@ def step(self, action: ActionType) -> None:
if self._prev_obs is None:
self.observe(self.agent_selection)
if isinstance(self._prev_obs, dict):
assert self._prev_obs is not None
assert (
"action_mask" in self._prev_obs
), f"`action_mask` not found in dictionary observation: {self._prev_obs}. Action mask must either be in `observation['action_mask']` or `info['action_mask']` to use TerminateIllegalWrapper."
Expand All @@ -60,7 +59,7 @@ def step(self, action: ActionType) -> None:
self.terminations[self.agent_selection]
or self.truncations[self.agent_selection]
):
self._was_dead_step(action) # pyright: ignore[reportGeneralTypeIssues]
self.env.unwrapped._was_dead_step(action)
elif (
not self.terminations[self.agent_selection]
and not self.truncations[self.agent_selection]
Expand All @@ -70,12 +69,10 @@ def step(self, action: ActionType) -> None:
self.env.unwrapped._cumulative_rewards[self.agent_selection] = 0
self.env.unwrapped.terminations = {d: True for d in self.agents}
self.env.unwrapped.truncations = {d: True for d in self.agents}
self._prev_obs = None
self._prev_info = None
self.env.unwrapped.rewards = {d: 0 for d in self.truncations}
self.env.unwrapped.rewards[current_agent] = float(self._illegal_value)
self._accumulate_rewards()
self._deads_step_first()
self.env.unwrapped._accumulate_rewards()
self.env.unwrapped._deads_step_first()
self._terminated = True
else:
super().step(action)
Expand Down
71 changes: 69 additions & 2 deletions test/wrapper_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,13 @@
import pytest

from pettingzoo.butterfly import pistonball_v6
from pettingzoo.classic import texas_holdem_no_limit_v6
from pettingzoo.utils.wrappers import MultiEpisodeEnv, MultiEpisodeParallelEnv
from pettingzoo.classic import texas_holdem_no_limit_v6, tictactoe_v3
from pettingzoo.utils.wrappers import (
BaseWrapper,
MultiEpisodeEnv,
MultiEpisodeParallelEnv,
TerminateIllegalWrapper,
)


@pytest.mark.parametrize(("num_episodes"), [1, 2, 3, 4, 5, 6])
Expand Down Expand Up @@ -67,3 +72,65 @@ def test_multi_episode_parallel_env_wrapper(num_episodes) -> None:
assert (
steps == num_episodes * 125
), f"Expected to have 125 steps per episode, got {steps / num_episodes}."


def _do_game(env: TerminateIllegalWrapper, seed: int) -> None:
"""Run a single game with reproducible random moves."""
assert isinstance(
env, TerminateIllegalWrapper
), "test_terminate_illegal must use TerminateIllegalWrapper"
env.reset(seed)
for agent in env.agents:
# make the random moves reproducible
env.action_space(agent).seed(seed)

for agent in env.agent_iter():
_, _, termination, truncation, _ = env.last()

if termination or truncation:
env.step(None)
else:
action = env.action_space(agent).sample()
env.step(action)


def test_terminate_illegal() -> None:
"""Test for a problem with terminate illegal wrapper.
The problem is that env variables, including agent_selection, are set by
calls from TerminateIllegalWrapper to env functions. However, they are
called by the wrapper object, not the env so they are set in the wrapper
object rather than the base env object. When the code later tries to run,
the values get updated in the env code, but the wrapper pulls it's own
values that shadow them.
The test here confirms that is fixed.
"""
# not using env() because we need to ensure that the env is
# wrapped by TerminateIllegalWrapper
raw_env = tictactoe_v3.raw_env()
env = TerminateIllegalWrapper(raw_env, illegal_reward=-1)

_do_game(env, 42)
# bug is triggered by a corrupted state after a game is terminated
# due to an illegal move. So we need to run the game twice to
# see the effect.
_do_game(env, 42)

# get a list of what all the agent_selection values in the wrapper stack
unwrapped = env
agent_selections = []
while unwrapped != env.unwrapped:
# the actual value for this wrapper (or None if no value)
agent_selections.append(unwrapped.__dict__.get("agent_selection", None))
assert isinstance(unwrapped, BaseWrapper)
unwrapped = unwrapped.env

# last one from the actual env
agent_selections.append(unwrapped.__dict__.get("agent_selection", None))

# remove None from agent_selections
agent_selections = [x for x in agent_selections if x is not None]

# all values must be the same, or else the wrapper and env are mismatched
assert len(set(agent_selections)) == 1, "agent_selection mismatch"

0 comments on commit 1eef080

Please sign in to comment.