-
Notifications
You must be signed in to change notification settings - Fork 6.8k
[MXNET-637] Multidimensional LSTM example for MXNetR #12664
Changes from 3 commits
74022af
6f67502
611741a
9931e49
b5179c8
be32352
d4a025c
80301fc
cccf46f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,353 @@ | ||
LSTM time series example | ||
============================================= | ||
|
||
This tutorial shows how to use an LSTM model with multivariate data, and generate predictions from it. For demonstration purposes, we used an opensource pollution data. You can find the data on [GitHub](https://github.com/dmlc/web-data/tree/master/mxnet/tinyshakespeare). | ||
The tutorial is an illustration of how to use LSTM models with MXNetR. We are forecasting the air pollution with data recorded at the US embassy in Beijing, China for five years. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is MXNetR the name? I though of the bindings were "MXNet - {package}". So MXNet-R.... |
||
|
||
Dataset Attribution: | ||
"PM2.5 data of US Embassy in Beijing" (https://archive.ics.uci.edu/ml/datasets/Beijing+PM2.5+Data) | ||
We want to predict pollution levels(PM2.5 concentration) in the city given the above dataset. | ||
|
||
```r | ||
Dataset description: | ||
No: row number | ||
year: year of data in this row | ||
month: month of data in this row | ||
day: day of data in this row | ||
hour: hour of data in this row | ||
pm2.5: PM2.5 concentration | ||
DEWP: Dew Point | ||
TEMP: Temperature | ||
PRES: Pressure | ||
cbwd: Combined wind direction | ||
Iws: Cumulated wind speed | ||
Is: Cumulated hours of snow | ||
Ir: Cumulated hours of rain | ||
``` | ||
|
||
We use past PM2.5 concentration, dew point, temperature, pressure, wind speed, snow and rain to predict | ||
PM2.5 concentration levels | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. levels. |
||
|
||
Load and pre-process the Data | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. extra space after load and lower case data |
||
--------- | ||
Load in the data and preprocess it. It is assumed that the data has been downloaded in as csv file 'data.csv' locally. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The first step is to load in ... |
||
|
||
```r | ||
## Loading required packages | ||
library("readr") | ||
ankkhedia marked this conversation as resolved.
Show resolved
Hide resolved
|
||
library("dplyr") | ||
library("mxnet") | ||
library("abind") | ||
``` | ||
|
||
|
||
|
||
```r | ||
## Preprocessing steps | ||
|
||
Data <- read.csv(file="data.csv", header=TRUE, sep=",") | ||
|
||
## Extracting specific features from the dataset as variables for time series | ||
## We extract pollution, temperature, pressue, windspeed, snowfall and rainfall information from dataset | ||
|
||
df<-data.frame(Data$pm2.5, Data$DEWP,Data$TEMP, Data$PRES, Data$Iws, Data$Is, Data$Ir) | ||
df[is.na(df)] <- 0 | ||
|
||
## Now we normalise each of the feature set to a range(0,1) | ||
df<-matrix(as.matrix(df),ncol=ncol(df),dimnames=NULL) | ||
rangenorm <- function(x){(x-min(x))/(max(x)-min(x))} | ||
df <- apply(df,2, rangenorm) | ||
df<-t(df) | ||
``` | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. looks like an extra leading space here |
||
For using multidimesional data with MXNetR. We need to convert training data to the form | ||
(n_dim x seq_len x num_samples) and label should be of the form (seq_len x num_samples) or (1 x num_samples) | ||
depending on the LSTM flavour to be used(one-to-one/ many-to-one). Please note that MXNetR currently supports only these two flavours of RNN. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. kind of long statement.... break into two... |
||
We have used n_dim =7, seq_len = 100 and num_samples= 430. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can these parameters be part of the code. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. will do that |
||
|
||
```r | ||
## extract only required data from dataset | ||
trX<- df[1:n_dim, 25:(24+(seqlen* num_samples))] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. defined parameter is seq_len and used parameter is seqlen. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. good catch! |
||
## the label data(next PM2.5 concentraion) should be one time step ahead of the current PM2.5 concentration | ||
|
||
trY<- df[1,26:(25+(seqlen* num_samples))] | ||
## reshape the matrices in the format acceptable by MXNetR RNNs | ||
trainX<- trX | ||
dim(trainX) <- c(7, 100,430) | ||
trainY<- trY | ||
dim(trainY)<- c(100,430) | ||
|
||
``` | ||
|
||
|
||
|
||
Defining and training the network | ||
--------- | ||
|
||
```r | ||
batch.size = 32 | ||
# take first 300 samples for train - remaining 100 for evaluation | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. training |
||
train_ids <- 1:300 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I thought there were 430 samples... what you doing with the last 30? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. kept them for test samples. added |
||
eval_ids<- 301:400 | ||
|
||
## create dataiterators | ||
train.data <- mx.io.arrayiter(data = trainX[,,train_ids, drop = F], label = trainY[, train_ids], | ||
batch.size = batch.size, shuffle = TRUE) | ||
|
||
eval.data <- mx.io.arrayiter(data = trainX[,,eval_ids, drop = F], label = trainY[, eval_ids], | ||
batch.size = batch.size, shuffle = FALSE) | ||
|
||
## Create the symbol for RNN | ||
symbol <- rnn.graph(num_rnn_layer = 2, | ||
ankkhedia marked this conversation as resolved.
Show resolved
Hide resolved
|
||
num_hidden = 50, | ||
input_size = NULL, | ||
num_embed = NULL, | ||
num_decode = 1, | ||
masking = F, | ||
loss_output = "linear", | ||
dropout = 0.2, | ||
ignore_label = -1, | ||
cell_type = "lstm", | ||
output_last_state = T, | ||
config = "one-to-one") | ||
|
||
|
||
|
||
mx.metric.mse.seq <- mx.metric.custom("MSE", function(label, pred) { | ||
label = mx.nd.reshape(label, shape = -1) | ||
pred = mx.nd.reshape(pred, shape = -1) | ||
res <- mx.nd.mean(mx.nd.square(label-pred)) | ||
return(as.array(res)) | ||
}) | ||
|
||
|
||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. note that you can change this for GPU usage? |
||
ctx <- mx.cpu() | ||
|
||
initializer <- mx.init.Xavier(rnd_type = "gaussian", | ||
factor_type = "avg", | ||
magnitude = 3) | ||
|
||
optimizer <- mx.opt.create("adadelta", rho = 0.9, eps = 1e-5, wd = 1e-6, | ||
clip_gradient = 1, rescale.grad = 1/batch.size) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. An example of the inconsistent indent style across the file. Needs more careful format or using There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. addressed |
||
|
||
logger <- mx.metric.logger() | ||
epoch.end.callback <- mx.callback.log.train.metric(period = 10, logger = logger) | ||
|
||
## train the network | ||
system.time( | ||
model <- mx.model.buckets(symbol = symbol, | ||
train.data = train.data, | ||
eval.data = eval.data, | ||
num.round = 50, ctx = ctx, verbose = TRUE, | ||
metric = mx.metric.mse.seq, | ||
initializer = initializer, optimizer = optimizer, | ||
batch.end.callback = NULL, | ||
epoch.end.callback = epoch.end.callback) | ||
) | ||
|
||
``` | ||
Output: | ||
``` | ||
Start training with 1 devices | ||
[1] Train-MSE=0.0175756292417645 | ||
[1] Validation-MSE=0.0108831799589097 | ||
[2] Train-MSE=0.0116676720790565 | ||
[2] Validation-MSE=0.00835292390547693 | ||
[3] Train-MSE=0.0103536401875317 | ||
[3] Validation-MSE=0.00770004198420793 | ||
[4] Train-MSE=0.00992695298045874 | ||
[4] Validation-MSE=0.00748429435770959 | ||
[5] Train-MSE=0.00970045481808484 | ||
[5] Validation-MSE=0.00734121853020042 | ||
[6] Train-MSE=0.00956926480866969 | ||
[6] Validation-MSE=0.00723317882511765 | ||
[7] Train-MSE=0.00946674752049148 | ||
[7] Validation-MSE=0.00715298682916909 | ||
[8] Train-MSE=0.00936337062157691 | ||
[8] Validation-MSE=0.00708933407440782 | ||
[9] Train-MSE=0.00928824483416974 | ||
[9] Validation-MSE=0.00702768098562956 | ||
[10] Train-MSE=0.00921900537796319 | ||
[10] Validation-MSE=0.00698263343656436 | ||
[11] Train-MSE=0.00915476991795003 | ||
[11] Validation-MSE=0.00694422319065779 | ||
[12] Train-MSE=0.00911224479787052 | ||
[12] Validation-MSE=0.00691420421935618 | ||
[13] Train-MSE=0.0090605927631259 | ||
[13] Validation-MSE=0.00686828832840547 | ||
[14] Train-MSE=0.00901446407660842 | ||
[14] Validation-MSE=0.00685080053517595 | ||
[15] Train-MSE=0.0089907712303102 | ||
[15] Validation-MSE=0.00681731867371127 | ||
[16] Train-MSE=0.00894410968758166 | ||
[16] Validation-MSE=0.00680519745219499 | ||
[17] Train-MSE=0.00891360901296139 | ||
[17] Validation-MSE=0.00678778381552547 | ||
[18] Train-MSE=0.00887094167992473 | ||
[18] Validation-MSE=0.00675358629086986 | ||
[19] Train-MSE=0.00885531790554523 | ||
[19] Validation-MSE=0.00676276802551001 | ||
[20] Train-MSE=0.0088208335917443 | ||
[20] Validation-MSE=0.00674056768184528 | ||
[21] Train-MSE=0.00880425171926618 | ||
[21] Validation-MSE=0.00673307734541595 | ||
[22] Train-MSE=0.00879250690340996 | ||
[22] Validation-MSE=0.00670740590430796 | ||
[23] Train-MSE=0.00875497269444168 | ||
[23] Validation-MSE=0.00668720051180571 | ||
[24] Train-MSE=0.00873568719252944 | ||
[24] Validation-MSE=0.00669587979791686 | ||
[25] Train-MSE=0.00874641905538738 | ||
[25] Validation-MSE=0.00669469079002738 | ||
[26] Train-MSE=0.008697918523103 | ||
[26] Validation-MSE=0.00669995549833402 | ||
[27] Train-MSE=0.00869045881554484 | ||
[27] Validation-MSE=0.00670569541398436 | ||
[28] Train-MSE=0.00865633632056415 | ||
[28] Validation-MSE=0.00670662586344406 | ||
[29] Train-MSE=0.00868522766977549 | ||
[29] Validation-MSE=0.00668792036594823 | ||
[30] Train-MSE=0.0086129839066416 | ||
[30] Validation-MSE=0.00667576276464388 | ||
[31] Train-MSE=0.0086337742395699 | ||
[31] Validation-MSE=0.0067121529718861 | ||
[32] Train-MSE=0.00863495240919292 | ||
[32] Validation-MSE=0.0067587440717034 | ||
[33] Train-MSE=0.00863885483704507 | ||
[33] Validation-MSE=0.00670913810608909 | ||
[34] Train-MSE=0.00858410224318504 | ||
[34] Validation-MSE=0.00674143311334774 | ||
[35] Train-MSE=0.00860943677835166 | ||
[35] Validation-MSE=0.00671671854797751 | ||
[36] Train-MSE=0.00857279957272112 | ||
[36] Validation-MSE=0.00672605860745534 | ||
[37] Train-MSE=0.00857790051959455 | ||
[37] Validation-MSE=0.00671195174800232 | ||
[38] Train-MSE=0.00856402018107474 | ||
[38] Validation-MSE=0.00670708599500358 | ||
[39] Train-MSE=0.00855070641264319 | ||
[39] Validation-MSE=0.00669713690876961 | ||
[40] Train-MSE=0.00855873627588153 | ||
[40] Validation-MSE=0.00669847876997665 | ||
[41] Train-MSE=0.00854103988967836 | ||
[41] Validation-MSE=0.00672988337464631 | ||
[42] Train-MSE=0.00854658158496022 | ||
[42] Validation-MSE=0.0067430961644277 | ||
[43] Train-MSE=0.00850498480722308 | ||
[43] Validation-MSE=0.00670209160307422 | ||
[44] Train-MSE=0.00847653122618794 | ||
[44] Validation-MSE=0.00672520510852337 | ||
[45] Train-MSE=0.00853331410326064 | ||
[45] Validation-MSE=0.0066903488477692 | ||
[46] Train-MSE=0.0084140149410814 | ||
[46] Validation-MSE=0.00665930815739557 | ||
[47] Train-MSE=0.00842269244603813 | ||
[47] Validation-MSE=0.00667664298089221 | ||
[48] Train-MSE=0.00844420134089887 | ||
[48] Validation-MSE=0.00665349006885663 | ||
[49] Train-MSE=0.00839704093523324 | ||
[49] Validation-MSE=0.00666191370692104 | ||
[50] Train-MSE=0.00840363306924701 | ||
[50] Validation-MSE=0.00664306507678702 | ||
user system elapsed | ||
66.782 6.229 39.745 | ||
|
||
``` | ||
|
||
|
||
Inference on the network | ||
--------- | ||
Now we have trained the network. Let's use it for inference | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. inference. |
||
|
||
```r | ||
ctx <- mx.cpu() | ||
|
||
## We extract the state symbols for RNN | ||
|
||
internals <- model$symbol$get.internals() | ||
sym_state <- internals$get.output(which(internals$outputs %in% "RNN_state")) | ||
sym_state_cell <- internals$get.output(which(internals$outputs %in% "RNN_state_cell")) | ||
sym_output <- internals$get.output(which(internals$outputs %in% "loss_output")) | ||
symbol <- mx.symbol.Group(sym_output, sym_state, sym_state_cell) | ||
|
||
## We will predict 100 timestamps for 401stsamples since it was not used in training | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 401stsamples? |
||
pred_length = 100 | ||
predict <- numeric() | ||
|
||
## We pass the 400th sample through the network to get the weights and use it for predicting next 100 time stamps. | ||
data = mx.nd.array(trainX[, , 400, drop = F]) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. maybe I missed this but I'm not sure what's going on with the labels... |
||
label = mx.nd.array(trainY[, 400, drop = F]) | ||
|
||
|
||
|
||
infer.data <- mx.io.arrayiter(data = data, label = label, | ||
batch.size = 1, shuffle = FALSE) | ||
|
||
infer <- mx.infer.rnn.one(infer.data = infer.data, | ||
symbol = symbol, | ||
arg.params = model$arg.params, | ||
aux.params = model$aux.params, | ||
input.params = NULL, | ||
ctx = ctx) | ||
## Once we get the weights for the above time seriees, we try to predict the next 100 steps for this time series which is technically | ||
##our 401st time series | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. spacing on the comment (but does this need its own line?) |
||
|
||
real<- trainY[ ,401] | ||
|
||
## Now we iterate one by one to generate each of the next timestamp pollution values | ||
|
||
for (i in 1:pred_length) { | ||
|
||
data = mx.nd.array(trainX[, i, 401, drop = F]) | ||
label = mx.nd.array(trainY[i,401, drop = F]) | ||
infer.data <- mx.io.arrayiter(data = data, label = label, | ||
batch.size = 1, shuffle = FALSE) | ||
## note that we use rnn state values from previous iterations here | ||
infer <- mx.infer.rnn.one(infer.data = infer.data, | ||
symbol = symbol, | ||
ctx = ctx, | ||
arg.params = model$arg.params, | ||
aux.params = model$aux.params, | ||
input.params = list(rnn.state = infer[[2]], | ||
rnn.state.cell = infer[[3]])) | ||
|
||
pred <- infer[[1]] | ||
predict <- c(predict, as.numeric(as.array(pred))) | ||
|
||
} | ||
|
||
``` | ||
Now predict contains the predicted 100 values | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why not showing a plot instead of hundreds of numbers? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. added plots |
||
|
||
``` | ||
> predict | ||
[1] 0.07202110 0.07473785 0.07704198 0.07890865 0.07866135 0.07802615 0.07835323 0.07779201 | ||
[9] 0.07753669 0.07813390 0.07946113 0.08086129 0.08315894 0.08248840 0.08063691 0.07665294 | ||
[17] 0.07449745 0.07181541 0.07004550 0.06852350 0.06651592 0.06582737 0.06484741 0.06473041 | ||
[25] 0.06570850 0.06849141 0.07106550 0.07429304 0.07519219 0.07302019 0.07029570 0.07020824 | ||
[33] 0.07033765 0.07064262 0.07164756 0.07258270 0.07392646 0.07482745 0.07253803 0.06833689 | ||
[41] 0.06562686 0.06298455 0.06118675 0.06004642 0.05913667 0.05937991 0.06084058 0.06260598 | ||
[49] 0.06461339 0.06610494 0.06790128 0.06700534 0.06526335 0.06434848 0.06419372 0.06458106 | ||
[57] 0.06451128 0.06394381 0.06342814 0.06225721 0.06086770 0.06100054 0.06181625 0.06202789 | ||
[65] 0.06179357 0.06211190 0.06205740 0.06191149 0.06142320 0.06112600 0.06098184 0.06103123 | ||
[73] 0.06323517 0.06354101 0.06597340 0.06735678 0.06951007 0.07118392 0.07365122 0.07625698 | ||
[81] 0.07842132 0.07991818 0.07907465 0.07529922 0.07086860 0.06635273 0.06279082 0.05998828 | ||
[89] 0.05896929 0.05801799 0.05909069 0.05821043 0.05980247 0.06013399 0.06061675 0.05972625 | ||
[97] 0.06003752 0.06044054 0.06041266 0.06089244 | ||
|
||
> real | ||
[1] 0.20020121 0.21327968 0.21227364 0.20221328 0.19416499 0.19919517 0.18812877 0.18511066 | ||
[9] 0.19315895 0.20925553 0.22032193 0.25150905 0.22837022 0.21126761 0.17706237 0.18108652 | ||
[17] 0.16498994 0.16398390 0.15593561 0.13480885 0.13480885 0.12273642 0.12575453 0.13581489 | ||
[25] 0.16800805 0.18812877 0.20724346 0.19718310 0.15492958 0.12072435 0.13581489 0.13279678 | ||
[33] 0.13682093 0.14688129 0.15995976 0.18008048 0.18611670 0.15291751 0.11267606 0.12374245 | ||
[41] 0.10965795 0.10362173 0.09859155 0.09456740 0.10663984 0.11569416 0.12474849 0.13480885 | ||
[49] 0.13480885 0.13782696 0.10160966 0.07847082 0.07947686 0.08551308 0.09356137 0.08853119 | ||
[57] 0.07847082 0.06941650 0.06237425 0.05734406 0.07746479 0.08752515 0.09758551 0.10160966 | ||
[65] 0.11267606 0.10965795 0.11066398 0.10462777 0.10462777 0.10160966 0.10060362 0.11368209 | ||
[73] 0.11569416 0.12173038 0.11971831 0.13883300 0.14688129 0.15895372 0.18511066 0.19114688 | ||
[81] 0.20221328 0.17605634 0.13279678 0.10261569 0.08853119 0.07444668 0.07444668 0.08249497 | ||
[89] 0.07344064 0.09557344 0.07645875 0.10261569 0.07243461 0.06941650 0.05231388 0.05432596 | ||
[97] 0.06740443 0.06338028 0.06740443 0.07444668 | ||
``` | ||
The above tutorial is just for demonstration purposes and have not beeen tuned extensively for accuracy. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
...we used open source pollution data
and that link is wrong... goes to a shakespeare dataset, plus I'd move the link to be [pollution data] and then delete the last line.