Skip to content

Commit 6cbe612

Browse files
mikemahoney218hfrickjuliasilge
authored
Add random-split block CV (#20)
* Add random-split block CV * Wrap multi-part warnings in c() * Apply suggestions from code review Co-authored-by: Hannah Frick <hfrick@users.noreply.github.com> * Extract sgbp checker * Remove rdname * Changes from PR review * Apply suggestions from code review Co-authored-by: Julia Silge <julia.silge@gmail.com> * Redocument Co-authored-by: Hannah Frick <hfrick@users.noreply.github.com> Co-authored-by: Julia Silge <julia.silge@gmail.com>
1 parent d26e3b7 commit 6cbe612

11 files changed

+327
-40
lines changed

NAMESPACE

+3-1
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
# Generated by roxygen2: do not edit by hand
22

33
S3method(dplyr_reconstruct,spatial_clustering_cv)
4+
S3method(pretty,spatial_block_cv)
45
S3method(pretty,spatial_clustering_cv)
6+
S3method(print,spatial_block_cv)
57
S3method(print,spatial_clustering_cv)
68
S3method(vec_cast,data.frame.spatial_clustering_cv)
79
S3method(vec_cast,spatial_clustering_cv.data.frame)
@@ -16,7 +18,7 @@ S3method(vec_ptype2,tbl_df.spatial_clustering_cv)
1618
S3method(vec_restore,spatial_clustering_cv)
1719
export(analysis)
1820
export(assessment)
19-
export(pretty.spatial_clustering_cv)
21+
export(spatial_block_cv)
2022
export(spatial_clustering_cv)
2123
import(vctrs)
2224
importFrom(dplyr,dplyr_reconstruct)

NEWS.md

+3
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
# spatialsample (development version)
22

3+
* `spatial_block_cv()` is a new function for performing spatial block cross-validation.
4+
It currently supports randomly assigning blocks to folds.
5+
36
* `spatial_clustering_cv()` gains an argument, `cluster_function`, which
47
specifies what type of clustering to perform. `cluster_function = "kmeans"`,
58
the default, uses `stats::kmeans()` for k-means clustering, while

R/block_cv.R

+172
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
#' Spatial block cross-validation
2+
#'
3+
#' Block cross-validation splits the area of your data into a number of
4+
#' grid cells, or "blocks", and then assigns all data into folds based on the
5+
#' blocks they fall into.
6+
#'
7+
#' @details
8+
#' The grid blocks can be controlled by passing arguments to
9+
#' [sf::st_make_grid()] via `...`. Some particularly useful arguments include:
10+
#'
11+
#' * `cellsize` Target cellsize, expressed as the "diameter" (shortest
12+
#' straight-line distance between opposing sides; two times the apothem)
13+
#' of each block, in map units.
14+
#' * `n` The number of grid blocks in the x and y direction (columns, rows).
15+
#' * `square` A logical value indicating whether to create square (`TRUE`) or
16+
#' hexagonal (`FALSE`) cells.
17+
#'
18+
#' If both `cellsize` and `n` are provided, then the number of blocks requested
19+
#' by `n` of sizes specified by `cellsize` will be returned, likely not
20+
#' lining up with the bounding box of `data`. If only `cellsize`
21+
#' is provided, this function will return as many blocks of size
22+
#' `cellsize` as fit inside the bounding box of `data`. If only `n` is provided,
23+
#' then `cellsize` will be automatically adjusted to create the requested
24+
#' number of cells.
25+
#'
26+
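#' @param data An object of class `sf`. Other inputs (including bare `sfc`
#' geometry columns) are rejected; convert them with [sf::st_as_sf()] first.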
27+
#' @param method The method used to sample blocks for cross-validation folds.
28+
#' Currently, only `"random"` is supported.
29+
#' @inheritParams rsample::vfold_cv
30+
#' @param ... Arguments passed to [sf::st_make_grid()].
31+
#'
32+
#' @return A tibble with classes `spatial_block_cv`, `rset`, `tbl_df`, `tbl`,
33+
#' and `data.frame`. The results include a column for the data split objects
34+
#' and an identification variable `id`.
35+
#'
36+
#' @examples
37+
#' data(Smithsonian, package = "modeldata")
38+
#' smithsonian_sf <- sf::st_as_sf(Smithsonian,
39+
#' coords = c("longitude", "latitude"),
40+
#' # Set CRS to WGS84
41+
#' crs = 4326)
42+
#'
43+
#' spatial_block_cv(smithsonian_sf, v = 3)
44+
#'
45+
#' @references
46+
#'
47+
#' D. R. Roberts, V. Bahn, S. Ciuti, M. S. Boyce, J. Elith, G. Guillera-Arroita,
48+
#' S. Hauenstein, J. J. Lahoz-Monfort, B. Schröder, W. Thuiller, D. I. Warton,
49+
#' B. A. Wintle, F. Hartig, and C. F. Dormann. "Cross-validation strategies for
50+
#' data with temporal, spatial, hierarchical, or phylogenetic structure," 2016,
51+
#' Ecography 40(8), pp. 913-929, doi: 10.1111/ecog.02881.
52+
#'
53+
#' @export
54+
spatial_block_cv <- function(data, method = "random", v = 10, ...) {
  # `method` is matched against the values in the function signature, so
  # anything other than "random" is rejected here.
  method <- rlang::arg_match(method)

  # inherits() is the idiomatic S3 class test (rather than `%in% class()`).
  if (!inherits(data, "sf")) {
    rlang::abort(
      c(
        "`spatial_block_cv()` currently only supports `sf` objects.",
        i = "Try converting `data` to an `sf` object via `sf::st_as_sf()`."
      )
    )
  }

  # A missing CRS is detected with is.na() instead of comparing to NA_crs_.
  if (is.na(sf::st_crs(data))) {
    rlang::abort(
      c(
        "`spatial_block_cv()` requires your data to have an appropriate coordinate reference system (CRS).",
        i = "Try setting a CRS using `sf::st_set_crs()`."
      )
    )
  }

  grid_box <- sf::st_bbox(data)
  if (sf::st_is_longlat(data)) {
    # cf https://github.com/ropensci/stplanr/pull/467
    # basically: spherical geometry means sometimes the straight line of the
    # grid will exclude points within the bounding box
    #
    # so here we'll expand our boundary by 0.1% in order to always contain our
    # points within the grid. Elements 1 and 2 (the minima) are pushed down,
    # elements 3 and 4 (the maxima) are pushed up.
    #
    # NOTE(review): the padding is proportional to each coordinate's
    # magnitude, so a bound that is exactly 0 receives no padding at all.
    grid_box[1] <- grid_box[1] - abs(grid_box[1] * 0.001)
    grid_box[2] <- grid_box[2] - abs(grid_box[2] * 0.001)
    grid_box[3] <- grid_box[3] + abs(grid_box[3] * 0.001)
    grid_box[4] <- grid_box[4] + abs(grid_box[4] * 0.001)
  }

  # `...` (e.g. `cellsize`, `n`, `square`) is forwarded to the grid builder.
  grid_blocks <- sf::st_make_grid(grid_box, ...)

  # Dispatch on the sampling method; "random" is the only option so far.
  split_objs <- switch(
    method,
    "random" = random_block_cv(data, grid_blocks, v)
  )

  # The helper may have lowered `v` (fewer blocks than requested folds), so
  # read the value it actually used back out of its result, then drop the
  # bookkeeping column.
  v <- split_objs$v[[1]]
  split_objs$v <- NULL

  ## We remove the holdout indices since it will save space and we can
  ## derive them later when they are needed.
  split_objs$splits <- map(split_objs$splits, rm_out)

  ## Save some overall information
  cv_att <- list(v = v)

  new_rset(
    splits = split_objs$splits,
    ids = split_objs[, grepl("^id", names(split_objs))],
    attrib = cv_att,
    subclass = c("spatial_block_cv", "rset")
  )
}
111+
112+
# Randomly assign the grid blocks that contain data to `v` folds and build
# one split per fold.
#
# data:        the `sf` object being resampled.
# grid_blocks: grid geometries covering `data` (built by the caller with
#              sf::st_make_grid()).
# v:           requested number of folds; lowered, with a warning, when fewer
#              non-empty blocks are available.
#
# Returns a tibble with a list-column `splits`, an `id` column and a column
# `v` holding the number of folds actually used.
random_block_cv <- function(data, grid_blocks, v) {
  # Validate `v` before doing any spatial work so bad input fails fast with a
  # readable message, instead of an opaque `if (NA)` error or NA fold labels
  # further down.
  if (!is.numeric(v) || length(v) != 1) {
    rlang::abort("`v` must be a single integer.")
  }
  if (is.na(v) || v < 1) {
    rlang::abort("`v` must be a single positive integer.")
  }

  n <- nrow(data)

  # Drop grid blocks that do not intersect any observation; only the
  # remaining blocks are available for sampling.
  block_contains_points <- purrr::map_lgl(
    sf::st_intersects(grid_blocks, data),
    sgbp_is_not_empty
  )
  grid_blocks <- grid_blocks[block_contains_points]

  n_blocks <- length(grid_blocks)
  if (v > n_blocks) {
    rlang::warn(paste0(
      "Fewer than ", v, " blocks available for sampling; setting v to ",
      n_blocks, "."
    ))
    v <- n_blocks
  }

  # Label every block with a fold number (cycling 1..v so fold sizes differ
  # by at most one block), shuffle the labels, then split by label.
  grid_blocks <- sf::st_as_sf(grid_blocks)
  grid_blocks$fold <- sample(rep(seq_len(v), length.out = nrow(grid_blocks)))
  grid_blocks <- split_unnamed(grid_blocks, grid_blocks$fold)

  # grid_blocks is now a list with one element per fold, each holding the
  # blocks assigned to that fold.
  #
  # For every fold, sf::st_intersects(data, blocks) yields an sgbp list
  # (?sf::sgbp) that is nrow(data) elements long; an element is integer(0)
  # when that row intersects none of the fold's blocks.
  #
  # which() therefore returns the row numbers of `data` that fall inside
  # the fold's blocks.
  indices <- purrr::map(
    grid_blocks,
    function(blocks) which(
      purrr::map_lgl(
        sf::st_intersects(data, blocks),
        sgbp_is_not_empty
      )
    )
  )

  # default_complement() and make_splits() are defined outside this block;
  # presumably they derive the analysis rows as the complement of each fold's
  # rows and wrap them into split objects -- confirm against their source.
  indices <- lapply(indices, default_complement, n = n)
  split_objs <- purrr::map(
    indices,
    make_splits,
    data = data,
    class = "spatial_block_split"
  )
  tibble::tibble(
    splits = split_objs,
    id = names0(length(split_objs), "Fold"),
    v = v
  )
}
169+
170+
# Test one element of a sparse geometry binary predicate (sgbp) list:
# returns FALSE only when the element is exactly integer(0), i.e. it records
# no matches, and TRUE otherwise.
# See ?sf::sgbp for more information on the data structure
sgbp_is_not_empty <- function(x) {
  no_matches <- integer(0)
  !identical(x, no_matches)
}

R/labels.R

+21-11
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,27 @@
1-
#' Short Descriptions of spatial rsets
2-
#'
3-
#' Produce a character vector describing the spatial resampling method.
4-
#'
5-
#' @param x An `rset` object
6-
#' @param ... Not currently used.
7-
#' @return A character vector.
8-
#' @export pretty.spatial_clustering_cv
91
#' @export
10-
#' @method pretty spatial_clustering_cv
11-
#' @rdname pretty.spatial_clustering_cv
12-
#' @keywords internal
132
pretty.spatial_clustering_cv <- function(x, ...) {
  # The fold count is read from the `v` attribute of the object and
  # prefixed to a fixed description of the resampling scheme.
  n_folds <- attributes(x)$v
  paste0(n_folds, "-fold spatial cross-validation")
}
7+
8+
#' @export
9+
print.spatial_clustering_cv <- function(x, ...) {
  # Header line: the short description produced by pretty() for this object.
  header <- pretty(x)
  cat("# ", header, "\n")

  # Strip the resampling classes so the print() call below dispatches to the
  # method of the remaining classes instead of calling this one again.
  dropped <- c("spatial_clustering_cv", "rset")
  current <- class(x)
  class(x) <- current[is.na(match(current, dropped))]
  print(x, ...)
}
14+
15+
#' @export
16+
pretty.spatial_block_cv <- function(x, ...) {
  # The fold count is read from the `v` attribute of the object and
  # prefixed to a fixed description of the resampling scheme.
  n_folds <- attributes(x)$v
  paste0(n_folds, "-fold spatial block cross-validation")
}
21+
22+
#' @export
23+
print.spatial_block_cv <- function(x, ...) {
  # Header line: the short description produced by pretty() for this object.
  header <- pretty(x)
  cat("# ", header, "\n")

  # Strip the resampling classes so the print() call below dispatches to the
  # method of the remaining classes instead of calling this one again.
  dropped <- c("spatial_block_cv", "rset")
  current <- class(x)
  class(x) <- current[is.na(match(current, dropped))]
  print(x, ...)
}

R/spatial_clustering_cv.R

-7
Original file line numberDiff line numberDiff line change
@@ -128,10 +128,3 @@ spatial_clustering_splits <- function(data, dists, v = 10, cluster_function = c(
128128
id = names0(length(split_objs), "Fold")
129129
)
130130
}
131-
132-
#' @export
133-
print.spatial_clustering_cv <- function(x, ...) {
134-
cat("# ", pretty(x), "\n")
135-
class(x) <- class(x)[!(class(x) %in% c("spatial_clustering_cv", "rset"))]
136-
print(x, ...)
137-
}

_pkgdown.yml

+1
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ reference:
2121
- title: Resampling Methods
2222
contents:
2323
- spatial_clustering_cv
24+
- spatial_block_cv
2425
- title: Utilities
2526
contents:
2627
- reexports

man/pretty.spatial_clustering_cv.Rd

-20
This file was deleted.

man/spatial_block_cv.Rd

+65
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

spatialsample.Rproj

+1-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ AlwaysSaveHistory: Default
66

77
EnableCodeIndexing: Yes
88
UseSpacesForTab: Yes
9-
NumSpacesForTab: 4
9+
NumSpacesForTab: 2
1010
Encoding: UTF-8
1111

1212
RnwWeave: Sweave

tests/testthat/_snaps/block_cv.md

+17
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
# printing
2+
3+
# 10-fold spatial block cross-validation
4+
# A tibble: 10 x 2
5+
splits id
6+
<list> <chr>
7+
1 <split [2082/848]> Fold01
8+
2 <split [2570/360]> Fold02
9+
3 <split [2801/129]> Fold03
10+
4 <split [2848/82]> Fold04
11+
5 <split [2822/108]> Fold05
12+
6 <split [2685/245]> Fold06
13+
7 <split [2216/714]> Fold07
14+
8 <split [2836/94]> Fold08
15+
9 <split [2609/321]> Fold09
16+
10 <split [2901/29]> Fold10
17+

0 commit comments

Comments
 (0)