Skip to content

Commit

Permalink
Add sound obs to _get_obs
Browse files Browse the repository at this point in the history
  • Loading branch information
pseudo-rnd-thoughts authored Nov 27, 2024
1 parent 4be5e22 commit 9caf22c
Showing 1 changed file with 15 additions and 5 deletions.
20 changes: 15 additions & 5 deletions src/ale/python/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def __init__(
game sounds. This will lock emulation to the ROMs specified FPS
If `rgb_array` we'll return the `rgb` key in step metadata with
the current environment RGB frame.
sound_obs: bool => Use sound observation.
sound_obs: bool => Add teh sound from the frame to the observation.
Note:
- The game must be installed, see ale-import-roms, or ale-py-roms.
Expand Down Expand Up @@ -209,6 +209,12 @@ def __init__(
else:
raise error.Error(f"Unrecognized observation type: {self._obs_type}")

if self.sound_obs:
self.observation_space = spaces.Dict(
image=self.observation_space,
sound=spaces.Box(low=0, high=255, dtype=np.uint8, shape=(512,)),
)

def seed_game(self, seed: int | None = None) -> tuple[int, int]:
"""Seeds the internal and ALE RNG."""
ss = np.random.SeedSequence(seed)
Expand Down Expand Up @@ -321,19 +327,23 @@ def render(self) -> np.ndarray | None:
"Supported modes: `human`, `rgb_array`."
)

def _get_obs(self) -> np.ndarray:
def _get_obs(self) -> np.ndarray | dict[str, np.ndarray]:
"""Retrieves the current observation using `obs_type`."""
if self._obs_type == "ram":
return self.ale.getRAM()
image_obs = self.ale.getRAM()
elif self._obs_type == "rgb":
return self.ale.getScreenRGB()
image_obs = self.ale.getScreenRGB()
elif self._obs_type == "grayscale":
return self.ale.getScreenGrayscale()
image_obs = self.ale.getScreenGrayscale()
else:
raise error.Error(
f"Unrecognized observation type: {self._obs_type}, expected: 'ram', 'rgb' and 'grayscale'."
)

if self.sound_obs:
return {"image": image_obs, "sound": self.ale.getAudio()}
return image_obs

def _get_info(self) -> AtariEnvStepMetadata:
return {
"lives": self.ale.lives(),
Expand Down

0 comments on commit 9caf22c

Please sign in to comment.