From 9caf22c3ccc3044d466f80916a1dc82bc4dcd823 Mon Sep 17 00:00:00 2001 From: Mark Towers Date: Wed, 27 Nov 2024 21:45:35 +0000 Subject: [PATCH] Add sound obs to `_get_obs` --- src/ale/python/env.py | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/src/ale/python/env.py b/src/ale/python/env.py index ce506107..29c2027b 100644 --- a/src/ale/python/env.py +++ b/src/ale/python/env.py @@ -74,7 +74,7 @@ def __init__( game sounds. This will lock emulation to the ROMs specified FPS If `rgb_array` we'll return the `rgb` key in step metadata with the current environment RGB frame. - sound_obs: bool => Use sound observation. + sound_obs: bool => Add the sound from the frame to the observation. Note: - The game must be installed, see ale-import-roms, or ale-py-roms. @@ -209,6 +209,12 @@ def __init__( else: raise error.Error(f"Unrecognized observation type: {self._obs_type}") + if self.sound_obs: + self.observation_space = spaces.Dict( + image=self.observation_space, + sound=spaces.Box(low=0, high=255, dtype=np.uint8, shape=(512,)), + ) + def seed_game(self, seed: int | None = None) -> tuple[int, int]: """Seeds the internal and ALE RNG.""" ss = np.random.SeedSequence(seed) @@ -321,19 +327,23 @@ def render(self) -> np.ndarray | None: "Supported modes: `human`, `rgb_array`." ) - def _get_obs(self) -> np.ndarray: + def _get_obs(self) -> np.ndarray | dict[str, np.ndarray]: """Retrieves the current observation using `obs_type`.""" if self._obs_type == "ram": - return self.ale.getRAM() + image_obs = self.ale.getRAM() elif self._obs_type == "rgb": - return self.ale.getScreenRGB() + image_obs = self.ale.getScreenRGB() elif self._obs_type == "grayscale": - return self.ale.getScreenGrayscale() + image_obs = self.ale.getScreenGrayscale() else: raise error.Error( f"Unrecognized observation type: {self._obs_type}, expected: 'ram', 'rgb' and 'grayscale'." 
) + if self.sound_obs: + return {"image": image_obs, "sound": self.ale.getAudio()} + return image_obs + def _get_info(self) -> AtariEnvStepMetadata: return { "lives": self.ale.lives(),