diff --git a/minigrid/wrappers.py b/minigrid/wrappers.py index 569fa11d0..814c6e5f5 100644 --- a/minigrid/wrappers.py +++ b/minigrid/wrappers.py @@ -569,19 +569,17 @@ class FlatObsWrapper(ObservationWrapper): (2835,) """ - def __init__(self, env, maxStrLen=96): + def __init__(self, env, maxStrLen: int = 96): super().__init__(env) self.maxStrLen = maxStrLen self.numCharCodes = 28 - imgSpace = env.observation_space.spaces["image"] - imgSize = reduce(operator.mul, imgSpace.shape, 1) - + img_size = np.prod(env.observation_space["image"].shape) self.observation_space = spaces.Box( low=0, high=255, - shape=(imgSize + self.numCharCodes * self.maxStrLen,), + shape=(img_size + self.numCharCodes * self.maxStrLen,), dtype="uint8", ) @@ -598,12 +596,11 @@ def observation(self, obs): ), f"mission string too long ({len(mission)} chars)" mission = mission.lower() - strArray = np.zeros( - shape=(self.maxStrLen, self.numCharCodes), dtype="float32" - ) + str_array = np.zeros(shape=(self.maxStrLen, self.numCharCodes), dtype="uint8") + # as `numCharCodes` < 255 then we can use `uint8` for idx, ch in enumerate(mission): - if ch >= "a" and ch <= "z": + if "a" <= ch <= "z": chNo = ord(ch) - ord("a") elif ch == " ": chNo = ord("z") - ord("a") + 1 @@ -613,11 +610,11 @@ def observation(self, obs): raise ValueError( f"Character {ch} is not available in mission string." ) - assert chNo < self.numCharCodes, "%s : %d" % (ch, chNo) - strArray[idx, chNo] = 1 + assert chNo < self.numCharCodes, f"{ch} : {chNo:d}" + str_array[idx, chNo] = 1 self.cachedStr = mission - self.cachedArray = strArray + self.cachedArray = str_array obs = np.concatenate((image.flatten(), self.cachedArray.flatten()))