import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as mpatches

class GridWorld:
    def __init__(self, nrows: int, ncols: int, start_state: tuple, goal_state: tuple):
        """
        Initializes a new instance of the GridWorld class.

        Args:
            nrows (int): The number of rows in the grid.
            ncols (int): The number of columns in the grid.
            start_state (tuple): The starting state of the grid world.
            goal_state (tuple): The goal state of the grid world.

        Returns:
            None
        """
        self.num_rows = nrows
        self.num_cols = ncols
        self.start_state = start_state
        self.goal_state = goal_state
        self.obstacles = []  # empty until add_obstacles() is called
        self.reward_step = None
        self.reward_goal = None
        self.gamma = 1  # default discount factor is 1 (no discounting)
        self.actions = ["up", "down", "left", "right"]
        self.num_actions = len(self.actions)
        self.all_states = [(i, j) for i in range(self.num_rows) for j in range(self.num_cols)]
        self.num_states = self.num_cols * self.num_rows

    def add_obstacles(self, obstacles: list):
        """
        Adds obstacles to the grid world.

        Parameters:
            obstacles (list): A list of obstacle coordinates in the grid.

        Returns:
            None
        """
        self.obstacles = obstacles

        for obstacle in obstacles:
            self.all_states.remove(obstacle)

        self.num_states = len(self.all_states)

    def is_valid_state(self, state: tuple) -> bool:
        """
        Check if the given state is valid in the grid world.

        Args:
            state (tuple): The state to check in the form of (row, col).

        Returns:
            bool: True if the state is valid, False otherwise.

        Notes:
            - The state is considered valid if it is within the grid boundaries and not an obstacle.
            - The grid boundaries are defined by the attributes `num_rows` and `num_cols`.
            - The obstacles are defined by the attribute `obstacles`.
        """
        row, col = state
        return 0 <= row < self.num_rows and 0 <= col < self.num_cols and state not in self.obstacles

    def is_terminal(self, state: tuple) -> bool:
        """
        Check if the provided state is the goal state.

        Args:
            state (tuple): The state to check for being the goal state.

        Returns:
            bool: True if the state is the goal state, False otherwise.
        """
        return state == self.goal_state

    def add_rewards(self, reward_step: float, reward_goal: float):
        """
        Set the reward values for the step and goal states.

        Args:
            reward_step (float): The reward value for each step.
            reward_goal (float): The reward value for reaching the goal state.

        Returns:
            None
        """
        self.reward_step = reward_step
        self.reward_goal = reward_goal

    def add_discount_factor(self, gamma: float):
        """
        Set the discount factor for the GridWorld.

        Args:
            gamma (float): The discount factor to be set.

        Returns:
            None
        """
        self.gamma = gamma

    def get_next_state(self, state: tuple, action: str) -> tuple:
        """
        Calculate the next state based on the current state and action.

        Args:
            state (tuple): The current state of the grid world in the form of (row, col).
            action (str): The action to take in the grid world. Can be "up", "down", "left", or "right".

        Returns:
            tuple: The next state of the grid world.
        """
        state = list(state)
        next_state = state.copy()

        if action == "up":
            next_state[0] -= 1
        elif action == "down":
            next_state[0] += 1
        elif action == "left":
            next_state[1] -= 1
        elif action == "right":
            next_state[1] += 1

        next_state = tuple(next_state)
        state = tuple(state)
        if self.is_valid_state(next_state):
            return next_state
        else:
            return state  # invalid move: stay in the same state

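    # For example, on any grid: get_next_state((0, 0), "up") returns (0, 0)
    # because the move would leave the grid, while get_next_state((0, 0), "down")
    # returns (1, 0) whenever that cell is inside the grid and not an obstacle.
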
    def get_reward(self, state: tuple) -> float:
        """
        Returns the reward for the given state.

        Parameters:
            state (tuple): The current state of the grid world in the form of (row, col).

        Returns:
            float: self.reward_goal if the state is the goal state, otherwise
            self.reward_step. Invalid states (off-grid or obstacles) yield None.
        """

        if self.is_valid_state(state):
            return self.reward_goal if self.is_terminal(state) else self.reward_step

    def dynamics(self):
        """
        Initialize the environment dynamics for the GridWorld.

        The environment dynamics are defined by the `P` dictionary. `P` is a nested
        dictionary that maps each state in `all_states` to a dictionary that maps each
        action in `actions` to a list of three elements: the next state after taking
        that action in that state, the reward obtained for that transition, and the
        transition probability (always 1, since this grid world is deterministic).

        Parameters:
            None

        Returns:
            None
        """

        # Initialize the environment's dynamics
        # self.V = {state: 0 for state in self.all_states}
        # self.Q = {state: {action: 0 for action in self.actions} for state in self.all_states}
        self.P = {
            state: {
                action: [
                    self.get_next_state(state, action),
                    self.get_reward(self.get_next_state(state, action)),
                    1,
                ]
                for action in self.actions
            }
            for state in self.all_states
        }

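    # Example entry on an illustrative grid where (0, 1) is a free, non-goal cell:
    #   self.P[(0, 0)]["right"] == [(0, 1), self.reward_step, 1]
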
    def random_policy(self):
        """
        Generate a random policy for the grid world.

        Returns:
            dict: A dictionary representing the random policy.
            The keys are the states in the grid world, and the values are dictionaries
            that map each action to a probability of taking that action.
            The probabilities are all equal to 0.25.
        """

        return {state: {action: 0.25 for action in self.actions} for state in self.all_states}

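    # Shape of the returned mapping, e.g.:
    #   policy[(0, 0)] == {"up": 0.25, "down": 0.25, "left": 0.25, "right": 0.25}
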
    def visualize_gridWorld(self):
        """Print an ASCII view of the grid: S = start, G = goal, # = obstacle, . = free cell."""
        for i in range(self.num_rows):
            for j in range(self.num_cols):
                # start_state and goal_state are tuples and obstacles is a list of
                # tuples, so compare tuples directly instead of calling .tolist()
                if (i, j) == self.start_state:
                    print("S", end=" ")
                elif (i, j) == self.goal_state:
                    print("G", end=" ")
                elif (i, j) in self.obstacles:
                    print("#", end=" ")
                else:
                    print(".", end=" ")
            print()
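

# A minimal usage sketch (illustrative, not part of the original API): builds a
# hypothetical 4x4 grid with one obstacle, wires up rewards, discounting, and the
# dynamics table, then prints the ASCII layout. The coordinates and reward values
# below are assumptions chosen for the example.
def demo_grid_world():
    world = GridWorld(nrows=4, ncols=4, start_state=(0, 0), goal_state=(3, 3))
    world.add_obstacles([(1, 1)])  # obstacle cells are removed from all_states
    world.add_rewards(reward_step=-1.0, reward_goal=0.0)
    world.add_discount_factor(0.9)
    world.dynamics()  # builds P[state][action] = [next_state, reward, probability]
    world.visualize_gridWorld()
    return world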


# TODO: Work on the plot_grid and update_values functions
class GridWorldVisualization:
    def __init__(self, grid_world: GridWorld):
        """
        Initializes a new instance of the GridWorldVisualization class.

        Args:
            grid_world (GridWorld): An instance of the GridWorld class.

        Returns:
            None
        """
        self.grid_world = grid_world
        self.grid = np.zeros((grid_world.num_rows, grid_world.num_cols))

    def plot_grid_with_arrows(self, grid_world, grid_dict):
        """
        Plot the grid and draw an arrow in each state for every action whose value
        in `grid_dict` is 1 (i.e. a deterministic policy); states that map to an
        empty dict (such as the goal) are marked but get no arrow.
        """
        def plot_arrow(coord, direction):
            # With origin='upper', row 0 is at the top, so "up" is a negative dy.
            dx, dy = {'up': (0, -0.4), 'down': (0, 0.4), 'left': (-0.4, 0), 'right': (0.4, 0)}[direction]
            plt.arrow(coord[1], coord[0], dx, dy, head_width=0.1, head_length=0.1, fc='k', ec='k')

        grid = np.zeros((grid_world.num_rows, grid_world.num_cols))

        for coord, actions in grid_dict.items():
            if actions:
                for action, val in actions.items():
                    if val == 1:
                        grid[coord[0], coord[1]] = 1
                        plot_arrow(coord, action)
            else:
                grid[coord[0], coord[1]] = 1  # Goal

        plt.imshow(grid, cmap='gray', origin='upper')

        # Highlighting start state and goal state
        plt.scatter(grid_world.start_state[1], grid_world.start_state[0], color='green', marker='o', s=100, label='Start')
        plt.scatter(grid_world.goal_state[1], grid_world.goal_state[0], color='red', marker='x', s=100, label='Goal')

        plt.yticks(np.arange(-0.5, grid_world.num_rows + 0.5, step=1))
        plt.xticks(np.arange(-0.5, grid_world.num_cols + 0.5, step=1))
        plt.grid()

        plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left', borderaxespad=0.0)
        plt.show()
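

# A minimal driver sketch (assumed usage; the original file ships no driver):
# a hand-written deterministic policy is encoded as {state: {action: 1}} with an
# empty dict for the goal, matching what plot_grid_with_arrows expects. The
# policy below is illustrative, not an optimal one.
if __name__ == "__main__":
    world = demo_grid_world()
    viz = GridWorldVisualization(world)

    policy = {}
    for state in world.all_states:
        if state == world.goal_state:
            policy[state] = {}  # no arrow on the goal
        elif state[1] < world.num_cols - 1:
            policy[state] = {"right": 1}
        else:
            policy[state] = {"down": 1}

    viz.plot_grid_with_arrows(world, policy)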