Skip to content

Commit 7b18200

Browse files
Add randomized goal option
1 parent 5e2489b commit 7b18200

File tree

3 files changed

+33
-7
lines changed

3 files changed

+33
-7
lines changed

example.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import gymnasium as gym
22
import gym_pusht
33

4-
env = gym.make("gym_pusht/PushT-v0", render_mode="human")
4+
env = gym.make("gym_pusht/PushT-v1", render_mode="human")
55
observation, info = env.reset()
66

77
for _ in range(1000):

gym_pusht/__init__.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -4,5 +4,13 @@
44
id="gym_pusht/PushT-v0",
55
entry_point="gym_pusht.envs:PushTEnv",
66
max_episode_steps=300,
7-
kwargs={"obs_type": "state"},
7+
kwargs={"obs_type": "state", "randomize_goal": False},
8+
)
9+
10+
# Register a variant whose goal pose is randomized on every reset
11+
register(
12+
id="gym_pusht/PushT-v1",
13+
entry_point="gym_pusht.envs:PushTEnv",
14+
max_episode_steps=300,
15+
kwargs={"obs_type": "state", "randomize_goal": True},
816
)

gym_pusht/envs/pusht.py

+23-5
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,7 @@ def __init__(
144144
observation_height=96,
145145
visualization_width=680,
146146
visualization_height=680,
147+
randomize_goal=False,
147148
):
148149
super().__init__()
149150
# Observations
@@ -167,6 +168,13 @@ def __init__(
167168
self.block_cog = block_cog
168169
self.damping = damping
169170

171+
# Randomization
172+
self.randomize_goal = randomize_goal
173+
# Safe margins from walls for positioning objects
174+
self.margin = 100 # Margin from walls to avoid spawning too close to edges
175+
self.min_pos = np.array([self.margin, self.margin])
176+
self.max_pos = np.array([512 - self.margin, 512 - self.margin])
177+
170178
# If human-rendering is used, `self.window` will be a reference
171179
# to the window that we draw to. `self.clock` will be a clock that is used
172180
# to ensure that the environment is rendered at the correct framerate in
@@ -269,18 +277,27 @@ def reset(self, seed=None, options=None):
269277
super().reset(seed=seed)
270278
self._setup()
271279

280+
# Randomize goal if enabled
281+
if self.randomize_goal:
282+
# Randomize goal position and orientation
283+
goal_x = self.np_random.uniform(self.min_pos[0], self.max_pos[0])
284+
goal_y = self.np_random.uniform(self.min_pos[1], self.max_pos[1])
285+
goal_theta = self.np_random.uniform(0, 2 * np.pi)
286+
self.goal_pose = np.array([goal_x, goal_y, goal_theta])
287+
288+
# Handle state reset
272289
if options is not None and options.get("reset_to_state") is not None:
273290
state = np.array(options.get("reset_to_state"))
274291
else:
275292
# state = self.np_random.uniform(low=[50, 50, 100, 100, -np.pi], high=[450, 450, 400, 400, np.pi])
276293
rs = np.random.RandomState(seed=seed)
277294
state = np.array(
278295
[
279-
rs.randint(50, 450),
280-
rs.randint(50, 450),
281-
rs.randint(100, 400),
282-
rs.randint(100, 400),
283-
rs.randn() * 2 * np.pi - np.pi,
296+
self.np_random.uniform(self.min_pos[0], self.max_pos[0]), # agent_x
297+
self.np_random.uniform(self.min_pos[1], self.max_pos[1]), # agent_y
298+
self.np_random.uniform(self.min_pos[0], self.max_pos[0]), # block_x
299+
self.np_random.uniform(self.min_pos[1], self.max_pos[1]), # block_y
300+
self.np_random.uniform(0, 2 * np.pi), # block_angle
284301
],
285302
# dtype=np.float64
286303
)
@@ -446,6 +463,7 @@ def _setup(self):
446463
# Add agent, block, and goal zone
447464
self.agent = self.add_circle(self.space, (256, 400), 15)
448465
self.block, self._block_shapes = self.add_tee(self.space, (256, 300), 0)
466+
# Default goal pose; reset() overrides it afterwards when randomize_goal is enabled
449467
self.goal_pose = np.array([256, 256, np.pi / 4]) # x, y, theta (in radians)
450468
if self.block_cog is not None:
451469
self.block.center_of_gravity = self.block_cog

0 commit comments

Comments
 (0)