Commit d24b94c

ankkhedia authored and piyushghai committed

[MXNET-637] Multidimensional LSTM example for MXNetR (apache#12664)

* added R LSTM examples
* added tutorial to whitelist
* fix encoding
* added seed and fixed a few formatting issues
* addressed PR comments
* formatting fixes
* nit fixes
* fix epochs
* fixed tutorial link
1 parent e360db2 commit d24b94c

3 files changed: +605 -0 lines

R-package/vignettes/MultidimLstm.Rmd (+302 lines)
LSTM time series example
=============================================

This tutorial shows how to use an LSTM model with multivariate data and how to generate predictions from it. For demonstration purposes, we use an open-source [pollution dataset](https://archive.ics.uci.edu/ml/datasets/Beijing+PM2.5+Data).
The tutorial illustrates how to use LSTM models with MXNet-R. We forecast air pollution using data recorded at the US embassy in Beijing, China over five years.

Dataset Attribution:
"PM2.5 data of US Embassy in Beijing"
We want to predict pollution levels (PM2.5 concentration) in the city from this dataset.

```
Dataset description:
No: row number
year: year of data in this row
month: month of data in this row
day: day of data in this row
hour: hour of data in this row
pm2.5: PM2.5 concentration
DEWP: Dew Point
TEMP: Temperature
PRES: Pressure
cbwd: Combined wind direction
Iws: Cumulated wind speed
Is: Cumulated hours of snow
Ir: Cumulated hours of rain
```

We use the past PM2.5 concentration, dew point, temperature, pressure, wind speed, snow and rain to predict
PM2.5 concentration levels.

Load and pre-process the data
---------
The first step is to load the data and preprocess it. It is assumed that the data has been downloaded as a .csv file, data.csv, from the [pollution dataset](https://archive.ics.uci.edu/ml/datasets/Beijing+PM2.5+Data).

```r
## Load the required packages
library("readr")
library("dplyr")
library("mxnet")
library("abind")
```

```r
## Preprocessing steps
Data <- read.csv(file = "/Users/khedia/Downloads/data.csv",
                 header = TRUE,
                 sep = ",")

## Extract specific features from the dataset as variables for the time series:
## pollution, dew point, temperature, pressure, wind speed, snowfall and rainfall
df <- data.frame(Data$pm2.5,
                 Data$DEWP,
                 Data$TEMP,
                 Data$PRES,
                 Data$Iws,
                 Data$Is,
                 Data$Ir)
## replace missing values with 0
df[is.na(df)] <- 0

## Now we normalise each feature to the range (0, 1)
df <- matrix(as.matrix(df),
             ncol = ncol(df),
             dimnames = NULL)

rangenorm <- function(x) {
    (x - min(x))/(max(x) - min(x))
}
df <- apply(df, 2, rangenorm)
df <- t(df)
```
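As a quick illustration (this snippet is not part of the original tutorial), `rangenorm` maps the minimum of a vector to 0 and the maximum to 1:

```r
## Illustrative check of the scaling
rangenorm(c(2, 4, 6))
#> [1] 0.0 0.5 1.0
```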
To use multidimensional data with MXNet-R, we need to convert the training data to the form
(n_dim x seq_len x num_samples). For the one-to-one RNN flavour, the labels should be of the form (seq_len x num_samples), while for the many-to-one flavour, the labels should be of the form (1 x num_samples). Please note that MXNet-R currently supports only these two flavours of RNN.
We use n_dim = 7, seq_len = 100 and num_samples = 430: the dataset yields 430 samples, each 100 time steps long, and since we have seven input time series, each input has dimension seven at every time step.

```r
n_dim <- 7
seq_len <- 100
num_samples <- 430

## extract only the required data from the dataset
trX <- df[1:n_dim, 25:(24 + (seq_len * num_samples))]

## the label data (next PM2.5 concentration) should be one time step
## ahead of the current PM2.5 concentration
trY <- df[1, 26:(25 + (seq_len * num_samples))]

## reshape the matrices into the format accepted by MXNet-R RNNs
trainX <- trX
dim(trainX) <- c(n_dim, seq_len, num_samples)
trainY <- trY
dim(trainY) <- c(seq_len, num_samples)
```
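Before training, it is worth verifying that the reshape produced the intended layout. A minimal sanity check (illustrative, not in the original tutorial); the expected values follow from the constants above:

```r
## Sanity check on the reshaped arrays
dim(trainX)  # expected: 7 100 430 (n_dim x seq_len x num_samples)
dim(trainY)  # expected: 100 430   (seq_len x num_samples)
```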

Defining and training the network
---------

```r
batch.size <- 32

## take the first 300 samples for training and the next 100 for evaluation
train_ids <- 1:300
eval_ids <- 301:400

## The number of samples used for training and evaluation is arbitrary;
## the remaining samples are kept aside for testing purposes.

## create the data iterators
train.data <- mx.io.arrayiter(data = trainX[, , train_ids, drop = F],
                              label = trainY[, train_ids],
                              batch.size = batch.size, shuffle = TRUE)

eval.data <- mx.io.arrayiter(data = trainX[, , eval_ids, drop = F],
                             label = trainY[, eval_ids],
                             batch.size = batch.size, shuffle = FALSE)

## Create the symbol for the RNN
symbol <- rnn.graph(num_rnn_layer = 1,
                    num_hidden = 5,
                    input_size = NULL,
                    num_embed = NULL,
                    num_decode = 1,
                    masking = F,
                    loss_output = "linear",
                    dropout = 0.2,
                    ignore_label = -1,
                    cell_type = "lstm",
                    output_last_state = T,
                    config = "one-to-one")

## custom MSE metric that flattens the labels and predictions before comparing them
mx.metric.mse.seq <- mx.metric.custom("MSE", function(label, pred) {
    label = mx.nd.reshape(label, shape = -1)
    pred = mx.nd.reshape(pred, shape = -1)
    res <- mx.nd.mean(mx.nd.square(label - pred))
    return(as.array(res))
})

ctx <- mx.cpu()

initializer <- mx.init.Xavier(rnd_type = "gaussian",
                              factor_type = "avg",
                              magnitude = 3)

optimizer <- mx.opt.create("adadelta",
                           rho = 0.9,
                           eps = 1e-05,
                           wd = 1e-06,
                           clip_gradient = 1,
                           rescale.grad = 1/batch.size)

logger <- mx.metric.logger()
epoch.end.callback <- mx.callback.log.train.metric(period = 10,
                                                   logger = logger)

## train the network
system.time(model <- mx.model.buckets(symbol = symbol,
                                      train.data = train.data,
                                      eval.data = eval.data,
                                      num.round = 100,
                                      ctx = ctx,
                                      verbose = TRUE,
                                      metric = mx.metric.mse.seq,
                                      initializer = initializer,
                                      optimizer = optimizer,
                                      batch.end.callback = NULL,
                                      epoch.end.callback = epoch.end.callback))
```
Output:
```
Start training with 1 devices
[1] Train-MSE=0.197570244409144
[1] Validation-MSE=0.0153861071448773
[2] Train-MSE=0.0152517843060195
[2] Validation-MSE=0.0128299412317574
[3] Train-MSE=0.0124418652616441
[3] Validation-MSE=0.010827143676579
[4] Train-MSE=0.0105128229130059
[4] Validation-MSE=0.00940261723008007
[5] Train-MSE=0.00914482437074184
[5] Validation-MSE=0.00830172537826002
[6] Train-MSE=0.00813581114634871
[6] Validation-MSE=0.00747016374953091
[7] Train-MSE=0.00735094994306564
[7] Validation-MSE=0.00679832429159433
[8] Train-MSE=0.00672049634158611
[8] Validation-MSE=0.00623159145470709
[9] Train-MSE=0.00620287149213254
[9] Validation-MSE=0.00577476259786636
[10] Train-MSE=0.00577280316501856
[10] Validation-MSE=0.00539038667920977
..........
..........
[91] Train-MSE=0.00177705133100972
[91] Validation-MSE=0.00154715491225943
[92] Train-MSE=0.00177639147732407
[92] Validation-MSE=0.00154592350008897
[93] Train-MSE=0.00177577760769054
[93] Validation-MSE=0.00154474508599378
[94] Train-MSE=0.0017752077546902
[94] Validation-MSE=0.0015436161775142
[95] Train-MSE=0.00177468206966296
[95] Validation-MSE=0.00154253660002723
[96] Train-MSE=0.00177419915562496
[96] Validation-MSE=0.00154150440357625
[97] Train-MSE=0.0017737578949891
[97] Validation-MSE=0.00154051734716631
[98] Train-MSE=0.00177335749613121
[98] Validation-MSE=0.00153957353904843
[99] Train-MSE=0.00177299699280411
[99] Validation-MSE=0.00153867155313492
[100] Train-MSE=0.00177267640829086
[100] Validation-MSE=0.00153781197150238

   user  system elapsed
 21.937   1.914  13.402
```
Below we can see how the mean squared error varies with the number of epochs.

![png](https://github.com/dmlc/web-data/blob/master/mxnet/doc/tutorials/r/images/loss.png?raw=true)<!--notebook-skip-line-->
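
The tutorial does not include the code for this plot; below is a minimal sketch, assuming that `logger$train` and `logger$eval` (filled in by the epoch-end callback above) each hold one MSE value per epoch.

```r
## Sketch for plotting the logged metrics; assumes logger$train and
## logger$eval each contain one MSE value per epoch.
library(ggplot2)

loss_df <- data.frame(epoch = seq_along(logger$train),
                      train = logger$train,
                      validation = logger$eval)

ggplot(loss_df) +
    geom_line(aes(x = epoch, y = train, colour = "Train-MSE")) +
    geom_line(aes(x = epoch, y = validation, colour = "Validation-MSE")) +
    labs(x = "Epoch", y = "MSE", colour = "")
```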

Inference on the network
---------
Now that we have trained the network, let's use it for inference.

```r
## Extract the state symbols for the RNN
internals <- model$symbol$get.internals()
sym_state <- internals$get.output(which(internals$outputs %in% "RNN_state"))
sym_state_cell <- internals$get.output(which(internals$outputs %in% "RNN_state_cell"))
sym_output <- internals$get.output(which(internals$outputs %in% "loss_output"))
symbol <- mx.symbol.Group(sym_output, sym_state, sym_state_cell)

## We will predict 100 time steps for the 401st sample (the first of the test samples)
pred_length <- 100
predicted <- numeric()

## We pass the 400th sample through the network to obtain the RNN states and
## use them for predicting the next 100 time steps.
data <- mx.nd.array(trainX[, , 400, drop = F])
label <- mx.nd.array(trainY[, 400, drop = F])


## We create a data iterator for the input; note that the label is required
## only to create the iterator and is not used in the inference, so you can
## also use dummy values for it.
infer.data <- mx.io.arrayiter(data = data,
                              label = label,
                              batch.size = 1,
                              shuffle = FALSE)

infer <- mx.infer.rnn.one(infer.data = infer.data,
                          symbol = symbol,
                          arg.params = model$arg.params,
                          aux.params = model$aux.params,
                          input.params = NULL,
                          ctx = ctx)
## Once we have the states for the above time series, we predict the next
## 100 steps, which is technically our 401st time series.

actual <- trainY[, 401]

## Now we iterate step by step to generate the pollution value at each of the
## next time steps.

for (i in 1:pred_length) {

    data <- mx.nd.array(trainX[, i, 401, drop = F])
    label <- mx.nd.array(trainY[i, 401, drop = F])
    infer.data <- mx.io.arrayiter(data = data,
                                  label = label,
                                  batch.size = 1,
                                  shuffle = FALSE)
    ## note that we feed the RNN state values from the previous iteration here
    infer <- mx.infer.rnn.one(infer.data = infer.data,
                              symbol = symbol,
                              ctx = ctx,
                              arg.params = model$arg.params,
                              aux.params = model$aux.params,
                              input.params = list(rnn.state = infer[[2]],
                                                  rnn.state.cell = infer[[3]]))

    pred <- infer[[1]]
    predicted <- c(predicted, as.numeric(as.array(pred)))

}
```
Now `predicted` contains the 100 predicted values. We use ggplot to plot the actual and predicted values, as shown below.

![png](https://github.com/dmlc/web-data/blob/master/mxnet/doc/tutorials/r/images/sample_401.png?raw=true)<!--notebook-skip-line-->
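
The plotting code itself is not part of the tutorial; a minimal sketch of how such a comparison plot could be produced with ggplot, assuming `actual` and `predicted` each hold 100 values for the 401st sample as computed above:

```r
## Sketch of the actual-vs-predicted comparison plot (illustrative)
library(ggplot2)

plot_df <- data.frame(timestep = 1:pred_length,
                      actual = actual,
                      predicted = predicted)

ggplot(plot_df) +
    geom_line(aes(x = timestep, y = actual, colour = "actual")) +
    geom_line(aes(x = timestep, y = predicted, colour = "predicted")) +
    labs(x = "Time step", y = "Normalised PM2.5", colour = "")
```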

We repeated the above experiment to generate the next 100 values for the 301st time series and obtained the following results.

![png](https://github.com/dmlc/web-data/blob/master/mxnet/doc/tutorials/r/images/sample_301.png?raw=true)<!--notebook-skip-line-->

The above tutorial is just for demonstration purposes and has not been tuned extensively for accuracy.

For more tutorials on MXNet-R, head over to the [MXNet-R tutorials](https://mxnet.incubator.apache.org/tutorials/r/index.html).