
Commit 1c5440d

Author: Shubham Sanjay Kamble

Finished making the grid_world environment. Implemented the classical dynamic programming techniques of policy iteration and value iteration.

13 files changed: +520 −0 lines

ReadMe.md

Whitespace-only changes.

env/__init__.py

Whitespace-only changes.

env/grid_world.py

+251
@@ -0,0 +1,251 @@
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches


class GridWorld:
    def __init__(self, nrows: int, ncols: int, start_state: tuple, goal_state: tuple):
        """
        Initializes a new instance of the GridWorld class.

        Args:
            nrows (int): The number of rows in the grid.
            ncols (int): The number of columns in the grid.
            start_state (tuple): The starting state of the grid world.
            goal_state (tuple): The goal state of the grid world.

        Returns:
            None
        """
        self.num_rows = nrows
        self.num_cols = ncols
        self.start_state = start_state
        self.goal_state = goal_state
        self.obstacles = []  # empty list (not None) so membership checks work before add_obstacles is called
        self.reward_step = None
        self.reward_goal = None
        self.gamma = 1  # default discount factor is 1
        self.actions = ["up", "down", "left", "right"]
        self.num_actions = 4
        self.all_states = [(i, j) for i in range(self.num_rows) for j in range(self.num_cols)]
        self.num_states = self.num_cols * self.num_rows

    def add_obstacles(self, obstacles: list):
        """
        Adds obstacles to the grid world.

        Parameters:
            obstacles (list): A list of obstacle coordinates in the grid.

        Returns:
            None
        """
        self.obstacles = obstacles

        for obstacle in obstacles:
            self.all_states.remove(obstacle)

        self.num_states = len(self.all_states)

    def is_valid_state(self, state: tuple) -> bool:
        """
        Check if the given state is valid in the grid world.

        Args:
            state (tuple): The state to check in the form of (row, col).

        Returns:
            bool: True if the state is valid, False otherwise.

        Raises:
            None

        Notes:
            - The state is considered valid if it is within the grid boundaries and not an obstacle.
            - The grid boundaries are defined by the attributes `num_rows` and `num_cols`.
            - The obstacles are defined by the attribute `obstacles`.
        """
        row, col = state
        return 0 <= row < self.num_rows and 0 <= col < self.num_cols and state not in self.obstacles

    def is_terminal(self, state: tuple) -> bool:
        """
        Check if the provided state is the goal state.

        Args:
            state (tuple): The state to check for being the goal state.

        Returns:
            bool: True if the state is the goal state, False otherwise.
        """
        return state == self.goal_state

    def add_rewards(self, reward_step: float, reward_goal: float):
        """
        Set the reward values for the step and goal states.

        Args:
            reward_step (float): The reward value for each step.
            reward_goal (float): The reward value for reaching the goal state.

        Returns:
            None
        """
        self.reward_step = reward_step
        self.reward_goal = reward_goal

    def add_discount_factor(self, gamma: float):
        """
        Set the discount factor for the GridWorld.

        Args:
            gamma (float): The discount factor to be set.

        Returns:
            None
        """
        self.gamma = gamma

    def get_next_state(self, state: tuple, action: str) -> tuple:
        """
        Calculate the next state based on the current state and action.

        Args:
            state (tuple): The current state of the grid world in the form of (row, col).
            action (str): The action to take in the grid world. Can be "up", "down", "left", or "right".

        Returns:
            tuple: The next state of the grid world.
        """
        state = list(state)
        next_state = state.copy()

        if action == "up":
            next_state[0] -= 1
        elif action == "down":
            next_state[0] += 1
        elif action == "left":
            next_state[1] -= 1
        elif action == "right":
            next_state[1] += 1

        next_state = tuple(next_state)
        state = tuple(state)
        if self.is_valid_state(next_state):
            return next_state
        else:
            return state  # stay in the same state

    def get_reward(self, state: tuple) -> float:
        """
        Returns the reward for the given state.

        Parameters:
            state (tuple): The current state of the grid world in the form of (row, col).

        Returns:
            float: The reward value for the given state: self.reward_goal if the state is
            the goal state, self.reward_step otherwise.

        Raises:
            None
        """
        if self.is_valid_state(state):
            return self.reward_goal if self.is_terminal(state) else self.reward_step

    def dynamics(self):
        """
        Initialize the environment dynamics for the GridWorld.

        The environment dynamics are defined by the `P` dictionary. `P` is a nested dictionary
        that maps each state in `all_states` to a dictionary that maps each action in `actions`
        to a list of three elements: the next state after taking that action in that state,
        the reward obtained, and the transition probability (always 1, since the environment
        is deterministic).

        Parameters:
            None

        Returns:
            None
        """
        # Initialize the environment's dynamics
        # self.V = {state: 0 for state in self.all_states}
        # self.Q = {state: {action: 0 for action in self.actions} for state in self.all_states}
        self.P = {state: {action: [self.get_next_state(state, action),
                                   self.get_reward(self.get_next_state(state, action)),
                                   1]
                          for action in self.actions}
                  for state in self.all_states}

    def random_policy(self):
        """
        Generate a random policy for the grid world.

        Returns:
            dict: A dictionary representing the random policy. The keys are the states in the
            grid world, and the values are dictionaries that map each action to the probability
            of taking that action. With four actions, the probabilities are all equal to 0.25.
        """
        return {state: {action: 0.25 for action in self.actions} for state in self.all_states}

    def visualize_gridWorld(self):
        """Print an ASCII map of the grid: S = start, G = goal, # = obstacle, . = free cell."""
        for i in range(self.num_rows):
            for j in range(self.num_cols):
                if (i, j) == self.start_state:
                    print("S", end=" ")
                elif (i, j) == self.goal_state:
                    print("G", end=" ")
                elif (i, j) in self.obstacles:
                    print("#", end=" ")
                else:
                    print(".", end=" ")
            print()

# TODO: Work on the plot_grid and update_values functions
class GridWorldVisualization:
    def __init__(self, grid_world: GridWorld):
        """
        Initializes a new instance of the GridWorldVisualization class.

        Args:
            grid_world (GridWorld): An instance of the GridWorld class.

        Returns:
            None
        """
        self.grid_world = grid_world
        self.grid = np.zeros((grid_world.num_rows, grid_world.num_cols))

    def plot_grid_with_arrows(self, grid_world, grid_dict):
        """Plot the grid and draw an arrow in each cell for every action with probability 1
        in `grid_dict` (a state -> {action: probability} mapping, e.g. a greedy policy)."""
        def plot_arrow(coord, direction):
            # (dx, dy) offsets in plot coordinates; with origin='upper', negative dy points up.
            dx, dy = {'up': (0, -0.4), 'down': (0, 0.4), 'left': (-0.4, 0), 'right': (0.4, 0)}[direction]
            plt.arrow(coord[1], coord[0], dx, dy, head_width=0.1, head_length=0.1, fc='k', ec='k')

        grid = np.zeros((grid_world.num_rows, grid_world.num_cols))

        for coord, actions in grid_dict.items():
            if actions:
                for action, val in actions.items():
                    if val == 1:
                        grid[coord[0], coord[1]] = 1
                        plot_arrow(coord, action)
            else:
                grid[coord[0], coord[1]] = 1  # Goal

        # Cells with a greedy action (and the goal) are drawn white (1); the rest stay black (0).
        plt.imshow(grid, cmap='gray', origin='upper')

        # Highlighting start state and goal state
        plt.scatter(grid_world.start_state[1], grid_world.start_state[0], color='green', marker='o', s=100, label='Start')
        plt.scatter(grid_world.goal_state[1], grid_world.goal_state[0], color='red', marker='x', s=100, label='Goal')

        plt.yticks(np.arange(-0.5, grid_world.num_rows + 0.5, step=1))
        plt.xticks(np.arange(-0.5, grid_world.num_cols + 0.5, step=1))
        plt.grid()

        plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left', borderaxespad=0.0)
        plt.show()
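For orientation, the following illustrative snippet (not part of the commit) shows the transition format that dynamics() builds; the printed values follow directly from the code above.

# Illustrative usage of the GridWorld API above (not part of the commit).
from env.grid_world import GridWorld

gw = GridWorld(3, 4, start_state=(0, 0), goal_state=(2, 3))
gw.add_obstacles([(1, 1), (0, 2)])
gw.add_rewards(-1.0, 100.0)
gw.dynamics()

# P[state][action] = [next_state, reward, transition_probability]
print(gw.P[(0, 0)]["right"])  # [(0, 1), -1.0, 1]   an ordinary step
print(gw.P[(2, 2)]["right"])  # [(2, 3), 100.0, 1]  stepping onto the goal
print(gw.P[(0, 0)]["up"])     # [(0, 0), -1.0, 1]   bumping a wall leaves you in place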

environment.yml

+8
@@ -0,0 +1,8 @@
name: rl-projects
channels:
  - defaults
  - conda-forge

dependencies:
  - python=3.10
  - pip-tools
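This file can be recreated as a conda environment with the standard invocation conda env create -f environment.yml.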

plot_gridworld_example.py

+25
@@ -0,0 +1,25 @@
import numpy as np
from env.grid_world import GridWorld, GridWorldVisualization

if __name__ == '__main__':

    nrows = 3
    ncols = 4
    start_state = (0, 0)
    goal_state = (2, 3)
    obstacles = [(1, 1), (0, 2)]
    grid_world = GridWorld(nrows, ncols, start_state, goal_state)
    grid_world.add_obstacles(obstacles)
    grid_world.add_rewards(-1.0, 100.0)
    grid_world.dynamics()

    visualization = GridWorldVisualization(grid_world)
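As a sanity check, calling grid_world.visualize_gridWorld() on this 3x4 configuration would print the following ASCII map (S = start, G = goal, # = obstacle):

S . # .
. # . .
. . . G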

policy_iteration_gridworld.py

+43
@@ -0,0 +1,43 @@
import numpy as np
from env.grid_world import GridWorld, GridWorldVisualization
from tabular_methods.dynamic_programming import PolicyIteration

# simple grid
# nrows = 4
# ncols = 4
# start_state = (0, 0)
# goal_state = (3, 3)
# obstacles = [(1, 1), (0, 2)]

# hard grid
nrows = 20
ncols = 20
start_state = (0, 0)
goal_state = (19, 19)
obstacles = [(5, 5), (5, 6), (5, 7), (5, 8), (5, 9),
             (6, 9), (7, 9), (8, 9), (9, 9), (10, 9),
             (10, 8), (10, 7), (10, 6), (10, 5), (10, 4),
             (11, 4), (12, 4), (13, 4), (14, 4), (15, 4),
             (15, 5), (15, 6), (15, 7), (15, 8), (15, 9),
             (16, 9), (17, 9), (18, 9), (19, 9), (19, 10)]

grid_world = GridWorld(nrows, ncols, start_state, goal_state)
grid_world.add_obstacles(obstacles)
grid_world.add_rewards(-1.0, 100.0)
grid_world.dynamics()
policy = grid_world.random_policy()  # unused: PolicyIteration below does not take it as an argument

# Policy Iteration:
V, policy = PolicyIteration(grid_world, theta=0.0001)

# Visualize the results:
visualization = GridWorldVisualization(grid_world)
visualization.plot_grid_with_arrows(grid_world, policy)
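tabular_methods/dynamic_programming.py is not part of the excerpt above, so PolicyIteration is known here only through its call signature. A minimal sketch consistent with that signature, the env.P format built by GridWorld.dynamics(), and the policy format expected by plot_grid_with_arrows might look like the following; everything beyond the name and signature is an assumption, not the commit's actual implementation.

# Hypothetical sketch of PolicyIteration, consistent with the call
# V, policy = PolicyIteration(grid_world, theta=0.0001) above.
# Assumes env.P[s][a] == [next_state, reward, prob] as built by GridWorld.dynamics().
def PolicyIteration(env, theta=0.0001):
    V = {s: 0.0 for s in env.all_states}
    # The goal keeps an empty action dict so plot_grid_with_arrows marks it as the goal cell.
    policy = {s: ({} if env.is_terminal(s)
                  else {a: 1.0 / len(env.actions) for a in env.actions})
              for s in env.all_states}

    stable = False
    while not stable:
        # Policy evaluation: sweep until values change by less than theta.
        while True:
            delta = 0.0
            for s in env.all_states:
                if env.is_terminal(s):
                    continue
                # Transitions are deterministic (prob is always 1), so the expectation
                # over next states reduces to a single term per action.
                v = sum(policy[s][a] * (r + env.gamma * V[s2])
                        for a, (s2, r, prob) in env.P[s].items())
                delta = max(delta, abs(v - V[s]))
                V[s] = v
            if delta < theta:
                break

        # Policy improvement: act greedily with respect to V.
        stable = True
        for s in env.all_states:
            if env.is_terminal(s):
                continue
            best = max(env.actions,
                       key=lambda a: env.P[s][a][1] + env.gamma * V[env.P[s][a][0]])
            greedy = {a: 1.0 if a == best else 0.0 for a in env.actions}
            if greedy != policy[s]:
                stable = False
            policy[s] = greedy

    return V, policy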

tabular_methods/__init__.py

Whitespace-only changes.
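The commit message also mentions value iteration, which would presumably live in this package as well. A comparable hedged sketch against the same interface (the name ValueIteration and the return format are assumptions; the actual module is not shown):

# Hypothetical sketch of value iteration over the same GridWorld dynamics.
# Assumes env.P[s][a] == [next_state, reward, prob] and returns (V, policy)
# in the same format as the PolicyIteration sketch above.
def ValueIteration(env, theta=0.0001):
    V = {s: 0.0 for s in env.all_states}

    # Bellman optimality backups until the largest update falls below theta.
    while True:
        delta = 0.0
        for s in env.all_states:
            if env.is_terminal(s):
                continue
            v = max(r + env.gamma * V[s2] for s2, r, prob in env.P[s].values())
            delta = max(delta, abs(v - V[s]))
            V[s] = v
        if delta < theta:
            break

    # Extract the greedy policy; the goal keeps an empty dict for the plotting code.
    policy = {}
    for s in env.all_states:
        if env.is_terminal(s):
            policy[s] = {}
            continue
        best = max(env.actions,
                   key=lambda a: env.P[s][a][1] + env.gamma * V[env.P[s][a][0]])
        policy[s] = {a: 1.0 if a == best else 0.0 for a in env.actions}
    return V, policy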
