Skip to content

Commit

Permalink
[BugFix] Discovery obs (#137)
Browse files Browse the repository at this point in the history
* amend

* amend

* amend
  • Loading branch information
matteobettini authored Aug 30, 2024
1 parent 73bb583 commit ff58363
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 41 deletions.
5 changes: 3 additions & 2 deletions tests/test_scenarios/test_discovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,9 @@ def setup_env(
self.env.seed(0)

@pytest.mark.parametrize("n_agents", [1, 4])
def test_heuristic(self, n_agents, n_steps=50, n_envs=4):
self.setup_env(n_agents=n_agents, n_envs=n_envs)
@pytest.mark.parametrize("agent_lidar", [True, False])
def test_heuristic(self, n_agents, agent_lidar, n_steps=50, n_envs=4):
self.setup_env(n_agents=n_agents, n_envs=n_envs, use_agent_lidar=agent_lidar)
policy = discovery.HeuristicPolicy(True)

obs = self.env.reset()
Expand Down
82 changes: 43 additions & 39 deletions vmas/scenarios/discovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ def make_world(self, batch_dim: int, device: torch.device, **kwargs):
self._min_dist_between_entities = kwargs.pop("min_dist_between_entities", 0.2)
self._lidar_range = kwargs.pop("lidar_range", 0.35)
self._covering_range = kwargs.pop("covering_range", 0.25)
self.use_agent_lidar = kwargs.pop("use_agent_lidar", False)
self._agents_per_target = kwargs.pop("agents_per_target", 2)
self.targets_respawn = kwargs.pop("targets_respawn", True)
self.shared_reward = kwargs.pop("shared_reward", False)
Expand Down Expand Up @@ -57,9 +58,9 @@ def make_world(self, batch_dim: int, device: torch.device, **kwargs):
)

# Add agents
# entity_filter_agents: Callable[[Entity], bool] = lambda e: e.name.startswith(
# "agent"
# )
entity_filter_agents: Callable[[Entity], bool] = lambda e: e.name.startswith(
"agent"
)
entity_filter_targets: Callable[[Entity], bool] = lambda e: e.name.startswith(
"target"
)
Expand All @@ -69,24 +70,32 @@ def make_world(self, batch_dim: int, device: torch.device, **kwargs):
name=f"agent_{i}",
collide=True,
shape=Sphere(radius=self.agent_radius),
sensors=[
# Lidar(
# world,
# angle_start=0.05,
# angle_end=2 * torch.pi + 0.05,
# n_rays=12,
# max_range=self._lidar_range,
# entity_filter=entity_filter_agents,
# render_color=Color.BLUE,
# ),
Lidar(
world,
n_rays=15,
max_range=self._lidar_range,
entity_filter=entity_filter_targets,
render_color=Color.GREEN,
),
],
sensors=(
[
Lidar(
world,
n_rays=15,
max_range=self._lidar_range,
entity_filter=entity_filter_targets,
render_color=Color.GREEN,
)
]
+ (
[
Lidar(
world,
angle_start=0.05,
angle_end=2 * torch.pi + 0.05,
n_rays=12,
max_range=self._lidar_range,
entity_filter=entity_filter_agents,
render_color=Color.BLUE,
)
]
if self.use_agent_lidar
else []
)
),
)
agent.collision_rew = torch.zeros(batch_dim, device=device)
agent.covering_reward = agent.collision_rew.clone()
Expand Down Expand Up @@ -230,15 +239,9 @@ def agent_reward(self, agent):

def observation(self, agent: Agent):
lidar_1_measures = agent.sensors[0].measure()
# lidar_2_measures = agent.sensors[1].measure()
return torch.cat(
[
agent.state.pos,
agent.state.vel,
agent.state.pos,
lidar_1_measures,
# lidar_2_measures,
],
[agent.state.pos, agent.state.vel, lidar_1_measures]
+ ([agent.sensors[1].measure()] if self.use_agent_lidar else []),
dim=-1,
)

Expand Down Expand Up @@ -317,24 +320,25 @@ def compute_action(self, observation: torch.Tensor, u_range: float) -> torch.Ten
closest_point_on_circ_normal *= 0.1
des_pos = closest_point_on_circ + closest_point_on_circ_normal

# Move away from other agents within visibility range
lidar_agents = observation[:, 4:16]
agent_visible = torch.any(lidar_agents < 0.15, dim=1)
_, agent_dir_index = torch.min(lidar_agents, dim=1)
agent_dir = agent_dir_index / lidar_agents.shape[1] * 2 * torch.pi
agent_vec = torch.stack([torch.cos(agent_dir), torch.sin(agent_dir)], dim=1)
des_pos_agent = current_pos - agent_vec * 0.1
des_pos[agent_visible] = des_pos_agent[agent_visible]

# Move towards targets within visibility range
lidar_targets = observation[:, 16:28]
lidar_targets = observation[:, 4:19]
target_visible = torch.any(lidar_targets < 0.3, dim=1)
_, target_dir_index = torch.min(lidar_targets, dim=1)
target_dir = target_dir_index / lidar_targets.shape[1] * 2 * torch.pi
target_vec = torch.stack([torch.cos(target_dir), torch.sin(target_dir)], dim=1)
des_pos_target = current_pos + target_vec * 0.1
des_pos[target_visible] = des_pos_target[target_visible]

if observation.shape[-1] > 19:
# Move away from other agents within visibility range
lidar_agents = observation[:, 19:31]
agent_visible = torch.any(lidar_agents < 0.15, dim=1)
_, agent_dir_index = torch.min(lidar_agents, dim=1)
agent_dir = agent_dir_index / lidar_agents.shape[1] * 2 * torch.pi
agent_vec = torch.stack([torch.cos(agent_dir), torch.sin(agent_dir)], dim=1)
des_pos_agent = current_pos - agent_vec * 0.1
des_pos[agent_visible] = des_pos_agent[agent_visible]

action = torch.clamp(
(des_pos - current_pos) * 10,
min=-u_range,
Expand Down

0 comments on commit ff58363

Please sign in to comment.