diff --git a/vmas/scenarios/discovery.py b/vmas/scenarios/discovery.py index ef57735c..c3934503 100644 --- a/vmas/scenarios/discovery.py +++ b/vmas/scenarios/discovery.py @@ -28,7 +28,11 @@ def make_world(self, batch_dim: int, device: torch.device, **kwargs): self._min_dist_between_entities = kwargs.pop("min_dist_between_entities", 0.2) self._lidar_range = kwargs.pop("lidar_range", 0.35) self._covering_range = kwargs.pop("covering_range", 0.25) + self.use_agent_lidar = kwargs.pop("use_agent_lidar", False) + self.n_lidar_rays_entities = kwargs.pop("n_lidar_rays_entities", 15) + self.n_lidar_rays_agents = kwargs.pop("n_lidar_rays_agents", 12) + self._agents_per_target = kwargs.pop("agents_per_target", 2) self.targets_respawn = kwargs.pop("targets_respawn", True) self.shared_reward = kwargs.pop("shared_reward", False) @@ -74,7 +78,7 @@ def make_world(self, batch_dim: int, device: torch.device, **kwargs): [ Lidar( world, - n_rays=15, + n_rays=self.n_lidar_rays_entities, max_range=self._lidar_range, entity_filter=entity_filter_targets, render_color=Color.GREEN, @@ -86,7 +90,7 @@ def make_world(self, batch_dim: int, device: torch.device, **kwargs): world, angle_start=0.05, angle_end=2 * torch.pi + 0.05, - n_rays=12, + n_rays=self.n_lidar_rays_agents, max_range=self._lidar_range, entity_filter=entity_filter_agents, render_color=Color.BLUE, diff --git a/vmas/scenarios/flocking.py b/vmas/scenarios/flocking.py index 713b21e0..c15d3f24 100644 --- a/vmas/scenarios/flocking.py +++ b/vmas/scenarios/flocking.py @@ -20,6 +20,8 @@ def make_world(self, batch_dim: int, device: torch.device, **kwargs): n_obstacles = kwargs.pop("n_obstacles", 5) self._min_dist_between_entities = kwargs.pop("min_dist_between_entities", 0.15) + self.n_lidar_rays = kwargs.pop("n_lidar_rays", 12) + self.collision_reward = kwargs.pop("collision_reward", -0.1) self.dist_shaping_factor = kwargs.pop("dist_shaping_factor", 1) ScenarioUtils.check_kwargs_consumed(kwargs) @@ -51,7 +53,7 @@ def make_world(self, batch_dim: int, device: torch.device, **kwargs): sensors=[ Lidar( world, - n_rays=12, + n_rays=self.n_lidar_rays, max_range=0.2, entity_filter=goal_entity_filter, ) diff --git a/vmas/scenarios/navigation.py b/vmas/scenarios/navigation.py index 9943215c..cbef4234 100644 --- a/vmas/scenarios/navigation.py +++ b/vmas/scenarios/navigation.py @@ -41,6 +41,7 @@ def make_world(self, batch_dim: int, device: torch.device, **kwargs): self.lidar_range = kwargs.pop("lidar_range", 0.35) self.agent_radius = kwargs.pop("agent_radius", 0.1) self.comms_range = kwargs.pop("comms_range", 0) + self.n_lidar_rays = kwargs.pop("n_lidar_rays", 12) self.shared_rew = kwargs.pop("shared_rew", True) self.pos_shaping_factor = kwargs.pop("pos_shaping_factor", 1) @@ -115,7 +116,7 @@ def make_world(self, batch_dim: int, device: torch.device, **kwargs): [ Lidar( world, - n_rays=12, + n_rays=self.n_lidar_rays, max_range=self.lidar_range, entity_filter=entity_filter_agents, ),