
Commit c9f09b9

committed
small
1 parent 5371606 commit c9f09b9

File tree

3 files changed: +78 −51 lines changed


content/model-improvement/assessment-model-improvement.qmd

+3-1
@@ -104,7 +104,9 @@ Follow the [Evaluate your model with resampling section](https://www.tidymodels.
 
 Compute the RMSE for both models again. Of course nothing changes for the null model. Compare the new RMSE estimates obtained through CV with those obtained earlier. What did and didn't change?
 
-Run the code again that creates the CV folds and does the fitting. This time, choose a different value for the random seed. The RMSE values for the CV fits will change. That's just due to the randomness in the data splitting. If we had more data, we would expect less variability. The overall pattern between the changes in RMSE for the fits to the training data without CV and what we see with CV should still be the same.
+Also look at the standard error for the RMSE. Since you are now resampling, you get not just a single RMSE estimate but one for each sample, so you can look at the variation in RMSE. This gives you a good indication of how robust your model performance is.
+
+Finally, run the code again that creates the CV folds and does the fitting. This time, choose a different value for the random seed. The RMSE values for the CV fits will change. That's just due to the randomness in the data splitting. If we had more data, we would expect less variability. The overall pattern between the changes in RMSE for the fits to the training data without CV and what we see with CV should still be the same.
 
 ::: note
 If you want to get more robust RMSE estimates with CV, you can set `repeats` to some value. That creates more samples by repeating the whole CV procedure several times. In theory this gives more robust results. You might encounter some warning messages. This is likely because occasionally, by chance, the data is split in a way that some information (e.g., a certain value for `SEX` in our data) is missing from one of the folds. That can cause issues.
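The repeated-CV setup described in the note can be sketched with tidymodels. This is a minimal illustration, not the exercise's actual code: it uses the built-in `mtcars` data and a simple `mpg ~ wt` model as stand-ins, since `train_data` and the exercise's workflows are not available here.

```r
library(tidymodels)

set.seed(123)
# 5-fold CV repeated 3 times yields 15 resamples instead of 5
folds <- vfold_cv(mtcars, v = 5, repeats = 3)

# stand-in workflow; the exercise uses its own outcome/predictors
wf <- workflow() %>%
  add_model(linear_reg() %>% set_engine("lm")) %>%
  add_formula(mpg ~ wt)

fit_cv <- fit_resamples(wf, resamples = folds)

# mean RMSE/R^2 across resamples, plus std_err: the standard error
# of the metric that the text suggests inspecting
collect_metrics(fit_cv)
```

With `repeats` set, `collect_metrics()` averages over all repeats, which is what tends to stabilize the RMSE estimate.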

myresources/improvement-exercise-solution/improvement-exercise.R

+75-50
@@ -26,9 +26,9 @@ Ntest = nrow(test_data)
 ## ---- model1 --------
 mod <- linear_reg() %>% set_engine("lm")
 wflow1 <-
-workflow() %>%
-add_model(mod) %>%
-add_formula(Y ~ DOSE)
+  workflow() %>%
+  add_model(mod) %>%
+  add_formula(Y ~ DOSE)
 fit1 <- wflow1 %>% fit(data = train_data)
 
 
@@ -46,20 +46,20 @@ pred0 <- rep(mean(train_data$Y),Ntrain)
 ## ---- rmse --------
 # Compute the RMSE and R squared for model 1
 rmse_train_1 <- bind_cols(train_data, pred1) %>%
-rmse(truth = Y, estimate = .pred)
+  rmse(truth = Y, estimate = .pred)
 
 # Compute the RMSE and R squared for model 2
 rmse_train_2 <- bind_cols(train_data, pred2) %>%
-rmse(truth = Y, estimate = .pred)
+  rmse(truth = Y, estimate = .pred)
 
 # Compute RMSE for a dumb null model
 rmse_train_0 <- rmse_vec(truth = train_data$Y, estimate = pred0)
 
 # Print the results
 metrics = data.frame(model = c("null model","model 1","model 2"),
-rmse = c(rmse_train_0,
-rmse_train_1$.estimate,
-rmse_train_2$.estimate) )
+  rmse = c(rmse_train_0,
+           rmse_train_1$.estimate,
+           rmse_train_2$.estimate) )
 print(metrics)
 
 ## ---- cross-validation --------
@@ -74,35 +74,35 @@ rmse_cv_2 <- collect_metrics(fit2_cv)$mean[1]
 
 # Print the results
 metrics_cv = data.frame(model = c("null","model 1","model 2"),
-rmse = c(rmse_train_0, rmse_cv_1, rmse_cv_2) )
+  rmse = c(rmse_train_0, rmse_cv_1, rmse_cv_2) )
 print(metrics_cv)
 
 ## ---- obs-pred-plot --------
 pred0a <- data.frame(predicted = pred0, model = "model 0")
-pred1a <- data.frame(predicted = as.numeric(unlist(pred1)), model = "model 1")
-pred2a <- data.frame(predicted = as.numeric(unlist(pred2)), model = "model 2")
+pred1a <- data.frame(predicted = pred1$.pred, model = "model 1")
+pred2a <- data.frame(predicted = pred2$.pred, model = "model 2")
 
 plot_data <- bind_rows(pred0a,pred1a,pred2a) %>%
-mutate(observed = rep(train_data$Y,3))
+  mutate(observed = rep(train_data$Y,3))
 
 p1 <- plot_data %>% ggplot() +
-geom_point(aes(x = observed, y = predicted, color = model, shape = model)) +
-labs(x = "Observed", y = "Predicted", title = "Predicted vs Observed") +
-geom_abline(intercept = 0, slope = 1, linetype = "dashed", color = "black") +
-scale_x_continuous(limits=c(0,5000)) +
-scale_y_continuous(limits=c(0,5000)) +
-theme_minimal()
+  geom_point(aes(x = observed, y = predicted, color = model, shape = model)) +
+  labs(x = "Observed", y = "Predicted", title = "Predicted vs Observed") +
+  geom_abline(intercept = 0, slope = 1, linetype = "dashed", color = "black") +
+  scale_x_continuous(limits=c(0,5000)) +
+  scale_y_continuous(limits=c(0,5000)) +
+  theme_minimal()
 plot(p1)
 
 
 ## ---- residuals-plot --------
 plot_data1 <- plot_data |> mutate(residuals = predicted-observed) |> filter(model == "model 2")
 p1a <- plot_data1 %>% ggplot() +
-geom_point(aes(x = predicted, y = residuals, color = model, shape = model)) +
-labs(x = "Predicted", y = "Residuals", title = "Residuals vs Predicted") +
-geom_abline(intercept = 0, slope = 0, linetype = "dashed", color = "black") +
-scale_y_continuous(limits=c(-2500,2500)) +
-theme_minimal()
+  geom_point(aes(x = predicted, y = residuals, color = model, shape = model)) +
+  labs(x = "Predicted", y = "Residuals", title = "Residuals vs Predicted") +
+  geom_abline(intercept = 0, slope = 0, linetype = "dashed", color = "black") +
+  scale_y_continuous(limits=c(-2500,2500)) +
+  theme_minimal()
 plot(p1a)
 
 
@@ -112,52 +112,77 @@ plot(p1a)
 Nsamp = 100 #number of samples
 set.seed(rngseed)
 # create samples
-dat_bs <- train_data |> rsample::bootstraps(times = Nsamp, apparent = TRUE)
+dat_bs <- train_data |> rsample::bootstraps(times = Nsamp)
 
 #set up empty arrays to store predictions for each sample
 pred_bs = array(0, dim=c(Nsamp,Ntrain))
 
 #loop over each bootstrap sample, fit model, then predict and record predictions
-for (i in 1:Nsamp)
-{
-dat_sample = rsample::analysis(dat_bs$splits[[i]])
-fit_bs <- wflow2 |> fit(data = dat_sample)
-pred_bs[i,] <- fit_bs %>% predict(train_data) %>% unlist()
+for (i in 1:Nsamp) {
+  dat_sample = rsample::analysis(dat_bs$splits[[i]])
+  fit_bs <- wflow2 |> fit(data = dat_sample)
+  pred_df <- fit_bs %>% predict(train_data)
+  pred_bs[i,] <- pred_df$.pred %>% unlist()
 }
 
-#compute mean and 89% confidence interval for predictions
+#compute median and 89% confidence interval for predictions
 preds <- pred_bs |> apply(2, quantile, c(0.055, 0.5, 0.945)) |> t()
 
 
 #make plot showing uncertainty
-plot_data2 <- data.frame(median = preds[,2], lb = preds[,1],
-ub = preds[,3], observed = rep(train_data$Y,3), mean = pred2a$predicted)
+plot_data2 <- data.frame(
+  median = preds[,2],
+  lb = preds[,1],
+  ub = preds[,3],
+  observed = rep(train_data$Y,3),
+  mean = pred2a$predicted
+)
 
 p2 <- plot_data2 %>% ggplot() +
-geom_point(aes(x = observed, y = median), shape = 5, color = "blue") +
-geom_point(aes(x = observed, y = lb), shape = 4, color = "red") +
-geom_point(aes(x = observed, y = ub), shape = 4, color = "red") +
-geom_point(aes(x = observed, y = mean), shape = 6, color = "black") +
-labs(x = "Observed", y = "Predicted", title = "Predicted vs Observed") +
-geom_abline(intercept = 0, slope = 1, linetype = "dashed", color = "black") +
-scale_x_continuous(limits=c(0,5000)) +
-scale_y_continuous(limits=c(0,5000)) +
-theme_minimal()
+  geom_errorbar(aes(x = observed, ymin = lb, ymax = ub), width = 25) +
+  geom_point(
+    aes(x = observed, y = median, color = "median"),
+    shape = 5
+  ) +
+  geom_point(
+    aes(x = observed, y = mean, color = "mean"),
+    shape = 6
+  ) +
+  labs(x = "Observed", y = "Predicted", title = "Predicted vs Observed") +
+  geom_abline(intercept = 0, slope = 1, linetype = "dashed", color = "black") +
+  scale_x_continuous(limits=c(0,5000)) +
+  scale_y_continuous(limits=c(0,5000)) +
+  scale_color_manual(name = "stat", values = c("orange", "blue")) +
+  theme_minimal()
 plot(p2)
 
 
 
 ## ---- final testing --------
-predf <- fit2 %>% predict(test_data)
+predf <- fit2 %>% predict(test_data)
 plot_f <- predf %>% mutate(observed = rep(test_data$Y,1)) %>% rename(predicted = .pred)
 
-p3 <- ggplot() +
-geom_point(aes(x = observed, y = predicted), data = plot_data, color="black") +
-geom_point(aes(x = observed, y = predicted), data = plot_f, color="red", shape = 15) +
-labs(x = "Observed", y = "Predicted", title = "Predicted vs Observed") +
-geom_abline(intercept = 0, slope = 1, linetype = "dashed", color = "black") +
-scale_x_continuous(limits=c(0,5000)) +
-scale_y_continuous(limits=c(0,5000)) +
-theme_minimal()
+final_plot_data <-
+  dplyr::bind_rows(
+    "train" = dplyr::filter(plot_data, model == "model 2"),
+    "test" = plot_f,
+    .id = "set"
+  ) |>
+  tibble::tibble() |>
+  dplyr::select(-model)
+
+p3 <- ggplot(final_plot_data) +
+  aes(
+    x = observed,
+    y = predicted,
+    color = set,
+    shape = set
+  ) +
+  geom_point() +
+  labs(x = "Observed", y = "Predicted", title = "Predicted vs Observed") +
+  geom_abline(intercept = 0, slope = 1, linetype = "dashed", color = "black") +
+  scale_x_continuous(limits=c(0,5000)) +
+  scale_y_continuous(limits=c(0,5000)) +
+  theme_minimal()
 plot(p3)
 
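The `apply(2, quantile, ...)` step in the bootstrap hunk works because each column of `pred_bs` holds all bootstrap predictions for one observation. A tiny base-R illustration with made-up numbers (4 bootstrap replicates, 3 observations):

```r
# rows = bootstrap replicates, columns = observations; values are made up
pred_bs <- matrix(c(10, 12, 11, 13,
                    20, 22, 21, 23,
                    30, 32, 31, 33), nrow = 4)

# per-observation lower bound, median, and upper bound of an 89% interval
preds <- pred_bs |> apply(2, quantile, c(0.055, 0.5, 0.945)) |> t()
preds  # 3 rows (observations) x 3 columns ("5.5%", "50%", "94.5%")
```

The transpose puts observations back in rows, so `preds[,2]` is the bootstrap median used for plotting.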
File renamed without changes.

0 commit comments