--
-- Copyright (c) 2016, Horizon Robotics, Inc.
-- All rights reserved.
--
-- This source code is licensed under the MIT license found in the
-- LICENSE file in the root directory of this source tree. An additional grant
-- of patent rights can be found in the PATENTS file in the same directory.
--
-- Author: Yao Zhou, yao.zhou@hobot.cc
--
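
-- learner drives training and evaluation of a deeprl.agent: it picks an
-- environment (deeprl.carenv for the 'car' task, deeprl.envir otherwise),
-- plays epsilon-greedy episodes, stores transitions in the agent's replay
-- memory, and trains the policy network on sampled mini-batches.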
local learner = torch.class('deeprl.learner')

function learner:__init(config)
    self.task = config.task
    self.epoch = config.epoch
    self.env_config = config.env_config
    self.agent_config = config.agent_config
    self.agent = deeprl.agent(self.agent_config)
    -- pick the environment implementation that matches the task
    if config.task == 'car' then
        self.envir = deeprl.carenv(self.env_config)
    else
        self.envir = deeprl.envir(self.env_config)
    end
    self.epsilon = config.epsilon
    self.screen = deeprl.screen(config.env_config)
end
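
-- A minimal usage sketch (illustrative only: the field values below are
-- placeholders, and the entry-point script elsewhere in the repo is
-- responsible for building the full env_config/agent_config tables):
--
--   local config = {
--       task = 'car',              -- any other value selects the grid environment
--       epoch = 1000,              -- number of training epochs
--       epsilon = 1.0,             -- initial exploration rate
--       env_config = {},           -- forwarded to the environment and the screen
--       agent_config = {},         -- must provide n_actions for action sampling
--   }
--   local l = deeprl.learner(config)
--   if config.task == 'car' then l:run_car() else l:run() end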

function learner:run()
    local score = 0
    for i = 1, self.epoch do
        -- init environment
        local error = 0
        self.envir:reset()
        local game_over = false
        -- init state
        local cur_state = self.envir:observe()
        while game_over ~= true do
            -- epsilon-greedy action selection: explore with probability
            -- epsilon, otherwise act greedily on the predicted Q-values
            local action
            if math.randf() <= self.epsilon then
                action = math.random(1, self.agent_config.n_actions)
            else
                -- forward pass through the policy network
                local q = self.agent.policy_net:forward(cur_state)
                local max, idx = torch.max(q, 1)
                action = idx[1]
            end
            -- anneal epsilon towards 0.001
            if self.epsilon > 0.001 then
                self.epsilon = self.epsilon * 0.999
            end
            local next_state, reward, go = self.envir:act(action)
            game_over = go
            if reward == 1 then score = score + 100 end
            -- store the transition in the agent's replay memory
            self.agent:remember({
                input_state = cur_state,
                action = action,
                reward = reward,
                next_state = next_state,
                game_over = game_over,
            })
            -- self.screen:show(self.envir.state)
            cur_state = next_state
            -- batch training on a replay sample
            local inputs, targets = self.agent:generate_batch()
            error = error + self.agent:train(inputs, targets)
        end
        collectgarbage()
        print(string.format('Epoch %d : error = %f : Score %d', i, error, score))
    end
end
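
-- test: run the greedy policy (no exploration) for `steps` episodes,
-- rendering each frame and accumulating the score.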
function learner:test(steps)
    local score = 0
    for i = 1, steps do
        local game_over = false
        local row, col, pos = self.envir:reset()
        local cur_state = self.envir:observe()
        while not game_over do
            -- always pick the action with the highest Q-value
            local q = self.agent.policy_net:forward(cur_state)
            local max, idx = torch.max(q, 1)
            local action = idx[1]
            local next_state, reward, go, row, col, pos = self.envir:act(action)
            cur_state = next_state
            game_over = go
            reward = reward > 0 and 1 or 0
            score = score + reward * 100
            self.screen:show(self.envir.state)
        end
        print(string.format('step %d, score is %d', i, score))
    end
    os.exit()
end
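
-- run_car and test_car mirror run and test for the 'car' task; the score is
-- reset every epoch and the screen renders the car positions returned by the
-- environment.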
function learner:run_car()
    for i = 1, self.epoch do
        -- init environment
        local error = 0
        local score = 0
        self.envir:reset()
        local game_over = false
        -- init state
        local cur_state = self.envir:observe()
        while game_over ~= true do
            -- epsilon-greedy action selection, as in learner:run()
            local action
            if math.randf() <= self.epsilon then
                action = math.random(1, self.agent_config.n_actions)
            else
                -- forward pass through the policy network
                local q = self.agent.policy_net:forward(cur_state)
                local max, idx = torch.max(q, 1)
                action = idx[1]
            end
            -- anneal epsilon towards 0.001
            if self.epsilon > 0.001 then
                self.epsilon = self.epsilon * 0.999
            end
            local next_state, reward, go = self.envir:act(action)
            game_over = go
            if reward == 1 then score = score + 100 end
            -- store the transition in the agent's replay memory
            self.agent:remember({
                input_state = cur_state,
                action = action,
                reward = reward,
                next_state = next_state,
                game_over = game_over,
            })
            -- self.screen:show(self.envir.state)
            cur_state = next_state
            -- batch training on a replay sample
            local inputs, targets = self.agent:generate_batch()
            error = error + self.agent:train(inputs, targets)
        end
        collectgarbage()
        print(string.format('Epoch %d : error = %f : Score %d', i, error, score))
    end
end

function learner:test_car(steps)
    for i = 1, steps do
        local score = 0
        local game_over = false
        self.envir:reset()
        local cur_state = self.envir:observe()
        while not game_over do
            -- always pick the action with the highest Q-value
            local q = self.agent.policy_net:forward(cur_state)
            local max, idx = torch.max(q, 1)
            local action = idx[1]
            local next_state, reward, go, positions = self.envir:act(action)
            cur_state = next_state
            game_over = go
            reward = reward > 0 and 1 or 0
            score = score + reward * 100
            self.screen:show_car(positions)
        end
        collectgarbage()
        print(string.format('step %d, score is %d', i, score))
    end
    os.exit()
end