
Commit aa28f8d

duckietownrl
1 parent a899bfd commit aa28f8d


54 files changed (+10287, -0 lines)

.gitattributes (+1 line)

@@ -0,0 +1 @@
duckiesim/manual/dataset/expert_data_36591.parquet filter=lfs diff=lfs merge=lfs -text

duckietownrl/__init__.py

Whitespace-only changes.

duckietownrl/algorithms/__init__.py

Whitespace-only changes.

duckietownrl/algorithms/ddpg.py (+238 lines)

@@ -0,0 +1,238 @@
import functools
import operator

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


# Implementation of Deep Deterministic Policy Gradients (DDPG)
# Paper: https://arxiv.org/abs/1509.02971


class ActorDense(nn.Module):
    def __init__(self, state_dim, action_dim, max_action):
        super(ActorDense, self).__init__()

        state_dim = functools.reduce(operator.mul, state_dim, 1)

        self.l1 = nn.Linear(state_dim, 400)
        self.l2 = nn.Linear(400, 300)
        self.l3 = nn.Linear(300, action_dim)

        self.max_action = max_action

        self.tanh = nn.Tanh()

    def forward(self, x):
        x = F.relu(self.l1(x))
        x = F.relu(self.l2(x))
        x = self.max_action * self.tanh(self.l3(x))
        return x


class ActorCNN(nn.Module):
    def __init__(self, action_dim, max_action):
        super(ActorCNN, self).__init__()

        # ONLY TRUE IN CASE OF DUCKIETOWN:
        flat_size = 32 * 9 * 14

        self.lr = nn.LeakyReLU()
        self.tanh = nn.Tanh()
        self.sigm = nn.Sigmoid()

        self.conv1 = nn.Conv2d(3, 32, 8, stride=2)
        self.conv2 = nn.Conv2d(32, 32, 4, stride=2)
        self.conv3 = nn.Conv2d(32, 32, 4, stride=2)
        self.conv4 = nn.Conv2d(32, 32, 4, stride=1)

        self.bn1 = nn.BatchNorm2d(32)
        self.bn2 = nn.BatchNorm2d(32)
        self.bn3 = nn.BatchNorm2d(32)
        self.bn4 = nn.BatchNorm2d(32)

        self.dropout = nn.Dropout(0.5)

        self.lin1 = nn.Linear(flat_size, 512)
        self.lin2 = nn.Linear(512, action_dim)

        self.max_action = max_action

    def forward(self, x):
        x = self.bn1(self.lr(self.conv1(x)))
        x = self.bn2(self.lr(self.conv2(x)))
        x = self.bn3(self.lr(self.conv3(x)))
        x = self.bn4(self.lr(self.conv4(x)))
        x = x.view(x.size(0), -1)  # flatten
        x = self.dropout(x)
        x = self.lr(self.lin1(x))

        # this is the vanilla implementation
        # but we're using a slightly different one
        # x = self.max_action * self.tanh(self.lin2(x))

        # the first action is squashed with a sigmoid into [0, max_action]
        # because we don't want the duckie to go backwards;
        # the second action keeps the usual tanh squashing
        x = self.lin2(x)
        x[:, 0] = self.max_action * self.sigm(x[:, 0])
        x[:, 1] = self.tanh(x[:, 1])

        return x


class CriticDense(nn.Module):
    def __init__(self, state_dim, action_dim):
        super(CriticDense, self).__init__()

        state_dim = functools.reduce(operator.mul, state_dim, 1)

        self.l1 = nn.Linear(state_dim, 400)
        self.l2 = nn.Linear(400 + action_dim, 300)
        self.l3 = nn.Linear(300, 1)

    def forward(self, x, u):
        x = F.relu(self.l1(x))
        x = F.relu(self.l2(torch.cat([x, u], 1)))
        x = self.l3(x)
        return x


class CriticCNN(nn.Module):
    def __init__(self, action_dim):
        super(CriticCNN, self).__init__()

        flat_size = 32 * 9 * 14

        self.lr = nn.LeakyReLU()

        self.conv1 = nn.Conv2d(3, 32, 8, stride=2)
        self.conv2 = nn.Conv2d(32, 32, 4, stride=2)
        self.conv3 = nn.Conv2d(32, 32, 4, stride=2)
        self.conv4 = nn.Conv2d(32, 32, 4, stride=1)

        self.bn1 = nn.BatchNorm2d(32)
        self.bn2 = nn.BatchNorm2d(32)
        self.bn3 = nn.BatchNorm2d(32)
        self.bn4 = nn.BatchNorm2d(32)

        self.dropout = nn.Dropout(0.5)

        self.lin1 = nn.Linear(flat_size, 256)
        self.lin2 = nn.Linear(256 + action_dim, 128)
        self.lin3 = nn.Linear(128, 1)

    def forward(self, states, actions):
        x = self.bn1(self.lr(self.conv1(states)))
        x = self.bn2(self.lr(self.conv2(x)))
        x = self.bn3(self.lr(self.conv3(x)))
        x = self.bn4(self.lr(self.conv4(x)))
        x = x.view(x.size(0), -1)  # flatten
        x = self.lr(self.lin1(x))
        x = self.lr(self.lin2(torch.cat([x, actions], 1)))  # concatenate actions with the conv features
        x = self.lin3(x)

        return x


class DDPG(object):
    def __init__(self, state_dim, action_dim, max_action, net_type):
        super(DDPG, self).__init__()
        print("Starting DDPG init")
        assert net_type in ["cnn", "dense"]

        self.state_dim = state_dim

        if net_type == "dense":
            self.flat = True
            self.actor = ActorDense(state_dim, action_dim, max_action).to(device)
            self.actor_target = ActorDense(state_dim, action_dim, max_action).to(device)
        else:
            self.flat = False
            self.actor = ActorCNN(action_dim, max_action).to(device)
            self.actor_target = ActorCNN(action_dim, max_action).to(device)

        print("Initialized Actor")
        self.actor_target.load_state_dict(self.actor.state_dict())
        self.actor_optimizer = torch.optim.Adam(self.actor.parameters(), lr=1e-4)
        print("Initialized Target+Opt [Actor]")
        if net_type == "dense":
            self.critic = CriticDense(state_dim, action_dim).to(device)
            self.critic_target = CriticDense(state_dim, action_dim).to(device)
        else:
            self.critic = CriticCNN(action_dim).to(device)
            self.critic_target = CriticCNN(action_dim).to(device)
        print("Initialized Critic")
        self.critic_target.load_state_dict(self.critic.state_dict())
        self.critic_optimizer = torch.optim.Adam(self.critic.parameters())
        print("Initialized Target+Opt [Critic]")

    def predict(self, state):

        # make sure the state is channel-first (3 x H x W), otherwise the prediction doesn't work
        assert state.shape[0] == 3

        if self.flat:
            state = torch.FloatTensor(state.reshape(1, -1)).to(device)
        else:
            state = torch.FloatTensor(np.expand_dims(state, axis=0)).to(device)
        return self.actor(state).cpu().data.numpy().flatten()

    def train(self, replay_buffer, iterations, batch_size=64, discount=0.99, tau=0.001):

        for it in range(iterations):

            # Sample replay buffer
            sample = replay_buffer.sample(batch_size, flat=self.flat)
            state = torch.FloatTensor(sample["state"]).to(device)
            action = torch.FloatTensor(sample["action"]).to(device)
            next_state = torch.FloatTensor(sample["next_state"]).to(device)
            # note: "done" here holds the not-done mask (1 - done flag), so the
            # bootstrap term below is zeroed out at episode ends
            done = torch.FloatTensor(1 - sample["done"]).to(device)
            reward = torch.FloatTensor(sample["reward"]).to(device)

            # Compute the target Q value
            target_Q = self.critic_target(next_state, self.actor_target(next_state))
            target_Q = reward + (done * discount * target_Q).detach()

            # Get current Q estimate
            current_Q = self.critic(state, action)

            # Compute critic loss
            critic_loss = F.mse_loss(current_Q, target_Q)

            # Optimize the critic
            self.critic_optimizer.zero_grad()
            critic_loss.backward()
            self.critic_optimizer.step()

            # Compute actor loss
            actor_loss = -self.critic(state, self.actor(state)).mean()

            # Optimize the actor
            self.actor_optimizer.zero_grad()
            actor_loss.backward()
            self.actor_optimizer.step()

            # Update the frozen target models
            for param, target_param in zip(self.critic.parameters(), self.critic_target.parameters()):
                target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)

            for param, target_param in zip(self.actor.parameters(), self.actor_target.parameters()):
                target_param.data.copy_(tau * param.data + (1 - tau) * target_param.data)

    def save(self, filename, directory):
        print("Saving to {}/{}_[actor|critic].pth".format(directory, filename))
        torch.save(self.actor.state_dict(), "{}/{}_actor.pth".format(directory, filename))
        print("Saved Actor")
        torch.save(self.critic.state_dict(), "{}/{}_critic.pth".format(directory, filename))
        print("Saved Critic")

    def load(self, filename, directory):
        self.actor.load_state_dict(
            torch.load("{}/{}_actor.pth".format(directory, filename), map_location=device)
        )
        self.critic.load_state_dict(
            torch.load("{}/{}_critic.pth".format(directory, filename), map_location=device)
        )
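The hard-coded flat_size = 32 * 9 * 14 in ActorCNN and CriticCNN only holds for one observation resolution. A minimal sanity-check sketch, assuming the Duckietown frames come in as 3 x 120 x 160 (an assumption implied by that constant, not stated in the commit):

import torch
import torch.nn as nn

# Assumed observation shape: 3 x 120 x 160 (illustrative only).
dummy = torch.zeros(1, 3, 120, 160)
convs = nn.Sequential(
    nn.Conv2d(3, 32, 8, stride=2),
    nn.Conv2d(32, 32, 4, stride=2),
    nn.Conv2d(32, 32, 4, stride=2),
    nn.Conv2d(32, 32, 4, stride=1),
)
# Activations and batch norm don't change spatial shapes, so convs alone suffices.
print(convs(dummy).shape)  # torch.Size([1, 32, 9, 14]) -> 32 * 9 * 14 = 4032

And a rough usage sketch for the DDPG wrapper. The observation shape, action size, and the replay-buffer interface below are assumptions for illustration; the actual buffer lives elsewhere in this commit, and its sample() signature is only inferred from how train() calls it:

import numpy as np
from duckietownrl.algorithms.ddpg import DDPG

state_dim = (3, 120, 160)   # assumed channel-first RGB observation shape
action_dim = 2              # assumed (velocity, steering) action
max_action = 1.0

agent = DDPG(state_dim, action_dim, max_action, net_type="cnn")

obs = np.random.rand(*state_dim).astype(np.float32)
action = agent.predict(obs)  # -> np.ndarray of shape (2,)

# train() expects a buffer whose sample(batch_size, flat=...) returns a dict
# with "state", "action", "reward", "next_state" and "done" arrays:
# agent.train(replay_buffer, iterations=100, batch_size=64)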
