"""Contains `wrappers` that can wrap around environments to modify their functionality.
The implementations of these wrappers are adopted from
`OpenAI <https://github.com/openai/baselines/blob/master/baselines/common/atari_wrappers.py>`_.
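
Example (a sketch of a typical composition; the environment id and the exact wrapper order are only
illustrative)::

    env = gym.make('BreakoutNoFrameskip-v4')
    env = AtariNoopResetWrapper(env, noop_max=30)
    env = AtariFrameskipWrapper(env, frameskip=4)
    env = AtariEpisodicLifeWrapper(env)
    env = AtariFireResetWrapper(env)
    env = AtariPreprocessFrameWrapper(env)
    env = AtariClipRewardWrapper(env)
    env = FrameStackWrapper(env, num_stacked_frames=4)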
"""

import time

import cv2
import gym
import numpy as np

cv2.ocl.setUseOpenCL(False)  # disable OpenCL, which can cause problems when running in multiple processes


class AtariPreprocessFrameWrapper(gym.ObservationWrapper):
    """A wrapper that scales the observations from 210x160 down to 84x84 and converts them from RGB to grayscale by
    extracting the luminance.
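
    Example (illustrative; any RGB Atari environment works)::

        env = AtariPreprocessFrameWrapper(gym.make('PongNoFrameskip-v4'))
        observation = env.reset()
        assert observation.shape == (84, 84, 1)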
"""
[docs] def __init__(self, env):
"""
Args:
env (:obj:`gym.Env`):
An environment that will be wrapped.
"""
super().__init__(env)
self.observation_space = gym.spaces.Box(low=0, high=255, shape=(84, 84, 1), dtype=np.uint8)

    def observation(self, frame):
        frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)  # convert to grayscale by extracting the luminance
        frame = cv2.resize(frame, (84, 84), interpolation=cv2.INTER_AREA)  # scale down to 84x84
        return np.expand_dims(frame, axis=-1)  # restore the channel axis


class AtariFrameskipWrapper(gym.Wrapper):
    """A wrapper that repeats the chosen action for `frameskip` consecutive frames, sums up the rewards, and returns
    the pixel-wise maximum of the last two frames.
"""

    def __init__(self, env, frameskip):
        """
        Args:
            env (:obj:`gym.Env`):
                An environment that will be wrapped.
            frameskip (:obj:`int`):
                The number of frames for which each action is repeated. Only the last two of these frames are used
                for the returned observation; the remaining frames are skipped.
        """
        super().__init__(env)
        self._frameskip = frameskip

    def step(self, action):
        frames = []
        total_reward = 0.0
        terminal = False
        info = None
        for _ in range(self._frameskip):
            next_frame, reward, terminal, info = self.env.step(action)
            frames.append(next_frame)
            total_reward += reward
            if terminal:
                break
        if len(frames) >= 2:
            # max-pool over the last two frames to remove flickering Atari sprites
            return np.amax((frames[-2], frames[-1]), axis=0), total_reward, terminal, info
        else:
            return frames[0], total_reward, terminal, info

    def reset(self, **kwargs):
        return self.env.reset(**kwargs)


class AtariClipRewardWrapper(gym.RewardWrapper):
    """A wrapper that clips the rewards between -1 and 1.
"""

    def __init__(self, env):
        """
        Args:
            env (:obj:`gym.Env`):
                An environment that will be wrapped.
        """
        super().__init__(env)

    def reward(self, reward):
        return np.clip(reward, -1., 1.)


class AtariEpisodicLifeWrapper(gym.Wrapper):
    """A wrapper that ends episodes (returns `terminal` = True) whenever a life is lost in the Atari game. The
    underlying environment is only truly reset when the `real` episode ends.
    """

    def __init__(self, env):
        """
        Args:
            env (:obj:`gym.Env`):
                An environment that will be wrapped.
        """
        super().__init__(env)
        self.lives = 0
        self.episode_terminal = True

    def step(self, action):
        next_observation, reward, terminal, info = self.env.step(action)
        self.episode_terminal = terminal  # remember whether the 'real' episode ended
        next_lives = info['ale.lives']
        if next_lives < self.lives:
            terminal = True  # treat a lost life as the end of an episode
        self.lives = next_lives
        return next_observation, reward, terminal, info

    def reset(self, **kwargs):
        if self.episode_terminal:
            # only reset the underlying environment if the 'real' episode has ended
            self.env.reset(**kwargs)
        observation, _, _, info = self.env.step(0)  # ACTION_NOOP; advance one frame to obtain 'ale.lives'
        self.lives = info['ale.lives']
        return observation


class AtariFireResetWrapper(gym.Wrapper):
    """A wrapper that executes the `'FIRE'` action after the environment has been reset. Some Atari games require
    this action to start an episode.
    """

    def __init__(self, env):
        """
        Args:
            env (:obj:`gym.Env`):
                An environment that will be wrapped.
        """
        super().__init__(env)

    def step(self, action):
        return self.env.step(action)

    def reset(self, **kwargs):
        self.env.reset(**kwargs)
        observation, _, terminal, _ = self.env.step(1)  # ACTION_FIRE
        if terminal:
            # the environment should not terminate right after a reset; reset again and continue
            print("WARNING: environment terminated while executing 'FIRE' during reset")
            observation = self.env.reset(**kwargs)
        return observation


class AtariNoopResetWrapper(gym.Wrapper):
    """A wrapper that executes a random number of `'NOOP'` actions after the environment has been reset, in order to
    randomize the initial state.
    """

    def __init__(self, env, noop_max):
        """
        Args:
            env (:obj:`gym.Env`):
                An environment that will be wrapped.
            noop_max (:obj:`int`):
                The maximum number of `'NOOP'` actions. The number is selected randomly between 1 and `noop_max`.
        """
        super().__init__(env)
        self.noop_max = noop_max

    def step(self, action):
        return self.env.step(action)

    def reset(self, **kwargs):
        observation = self.env.reset(**kwargs)
        num_noops = self.unwrapped.np_random.randint(1, self.noop_max + 1)
        for _ in range(num_noops):
            observation, _, terminal, _ = self.env.step(0)  # ACTION_NOOP
            if terminal:
                observation = self.env.reset(**kwargs)
        return observation


class RenderWrapper(gym.Wrapper):
    """A wrapper that calls :meth:`gym.Env.render` every step.
"""

    def __init__(self, env, fps=None):
        """
        Args:
            env (:obj:`gym.Env`):
                An environment that will be wrapped.
            fps (:obj:`int`, :obj:`float`, optional):
                If not None, the steps are slowed down to run at the specified frames per second by waiting
                1.0/`fps` seconds every step.
        """
        super().__init__(env)
        self._spf = 1.0 / fps if fps is not None else None  # seconds per frame

    def step(self, action):
        self.env.render()
        if self._spf is not None:
            time.sleep(self._spf)
        return self.env.step(action)

    def reset(self, **kwargs):
        return self.env.reset(**kwargs)


class FrameStackWrapper(gym.Wrapper):
    """A wrapper that stacks the last observations. The observations returned by this wrapper consist of the last
    `num_stacked_frames` frames stacked along the last axis.
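
    Example (sketch, stacking 4 of the 84x84x1 frames produced by :obj:`AtariPreprocessFrameWrapper`)::

        env = FrameStackWrapper(AtariPreprocessFrameWrapper(gym.make('PongNoFrameskip-v4')), num_stacked_frames=4)
        observation = env.reset()
        assert observation.shape == (84, 84, 4)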
"""

    def __init__(self, env, num_stacked_frames):
        """
        Args:
            env (:obj:`gym.Env`):
                An environment that will be wrapped.
            num_stacked_frames (:obj:`int`):
                The number of frames that will be stacked.
        """
        super().__init__(env)
        self._num_stacked_frames = num_stacked_frames
        stacked_low = np.repeat(env.observation_space.low, num_stacked_frames, axis=-1)
        stacked_high = np.repeat(env.observation_space.high, num_stacked_frames, axis=-1)
        self.observation_space = gym.spaces.Box(low=stacked_low, high=stacked_high, dtype=env.observation_space.dtype)
        self._stacked_frames = np.zeros_like(stacked_low)

    def step(self, action):
        next_frame, reward, terminal, info = self.env.step(action)
        self._stacked_frames = np.roll(self._stacked_frames, shift=-1, axis=-1)  # drop the oldest frame
        if terminal:
            self._stacked_frames.fill(0.0)
        self._stacked_frames[..., -1:] = next_frame  # insert the new frame at the end
        return self._stacked_frames, reward, terminal, info

    def reset(self, **kwargs):
        frame = self.env.reset(**kwargs)
        # fill the stack with copies of the first frame
        self._stacked_frames = np.repeat(frame, self._num_stacked_frames, axis=-1)
        return self._stacked_frames


class AtariInfoClearWrapper(gym.Wrapper):
    """A wrapper that removes unnecessary data from the `info` returned by :meth:`gym.Env.step`. This reduces the
    amount of data that has to be passed between processes.

    Warning:
        :obj:`AtariEpisodicLifeWrapper` relies on `info['ale.lives']` and therefore does not work afterwards, so it
        should be applied `before` this wrapper.
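
    Example (correct ordering; sketch)::

        env = AtariEpisodicLifeWrapper(env)  # reads info['ale.lives']
        env = AtariInfoClearWrapper(env)  # removes info['ale.lives'] afterwards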
"""

    def __init__(self, env):
        """
        Args:
            env (:obj:`gym.Env`):
                An environment that will be wrapped.
        """
        super().__init__(env)

    def step(self, action):
        observation, reward, terminal, info = self.env.step(action)
        del info['ale.lives']  # remove data that is not needed anymore
        return observation, reward, terminal, info

    def reset(self, **kwargs):
        return self.env.reset(**kwargs)


class EpisodeInfoWrapper(gym.Wrapper):
    """A wrapper that stores episode information in the `info` returned by :meth:`gym.Env.step` at the end of an
    episode. More specifically, when an episode terminates, `info` will contain the key `'episode'` whose value is a
    :obj:`dict` containing the `'total_reward'`, the cumulative reward of the episode.

    Note:
        If you want to get the cumulative reward of the entire episode, :obj:`AtariEpisodicLifeWrapper` should be
        applied `after` this wrapper.
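
    Example (sketch; runs one episode and prints its cumulative reward)::

        env = EpisodeInfoWrapper(env)
        env.reset()
        terminal = False
        while not terminal:
            observation, reward, terminal, info = env.step(env.action_space.sample())
        print(info['episode']['total_reward'])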
"""

    def __init__(self, env):
        """
        Args:
            env (:obj:`gym.Env`):
                An environment that will be wrapped.
        """
        super().__init__(env)
        self.total_reward = 0.0

    def step(self, action):
        observation, reward, terminal, info = self.env.step(action)
        self.total_reward += reward
        if terminal:
            episode_info = dict()
            episode_info['total_reward'] = self.total_reward
            info['episode'] = episode_info
            self.total_reward = 0.0
        return observation, reward, terminal, info

    def reset(self, **kwargs):
        self.total_reward = 0.0
        return self.env.reset(**kwargs)

    @staticmethod
    def get_episode_rewards_from_info_batch(infos):
        """Utility function that extracts the episode rewards, which are inserted by the :obj:`EpisodeInfoWrapper`,
        from the `infos`.

        Args:
            infos (:obj:`list` of :obj:`list`):
                A batch-major list of `infos` as returned by :meth:`~actorcritic.agents.Agent.interact`.

        Returns:
            :obj:`numpy.ndarray`:
                A batch-major array with the same shape as `infos`. It contains the episode reward of an `info` at
                the corresponding position, and :obj:`numpy.nan` wherever an `info` contains no episode reward.
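
        Example (illustrative `infos` for one environment and three steps)::

            infos = [[{}, {}, {'episode': {'total_reward': 21.0}}]]
            rewards = EpisodeInfoWrapper.get_episode_rewards_from_info_batch(infos)
            # rewards == [[nan, nan, 21.0]]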
"""
rewards = np.full_like(infos, np.nan, np.float32)
environments, steps = rewards.shape
for environment in range(environments):
for step in range(steps):
info = infos[environment][step]
if 'episode' in info:
reward = info['episode']['total_reward']
rewards[environment, step] = reward
return rewards