Skip to content

Commit fe2c3f4

Browse files
Add first draft of NNDM function (#141)
* Add first draft of NNDM function * Fix tests * Fix tests, site * Clean up tests * Pass call explicitly * Clean up * Update snaps * Fix several comments * Final fixes * One more comment
1 parent 0d25a04 commit fe2c3f4

15 files changed

+2078
-1393
lines changed

NAMESPACE

+12
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ S3method(vec_cast,data.frame.spatial_block_cv)
1515
S3method(vec_cast,data.frame.spatial_buffer_vfold_cv)
1616
S3method(vec_cast,data.frame.spatial_clustering_cv)
1717
S3method(vec_cast,data.frame.spatial_leave_location_out_cv)
18+
S3method(vec_cast,data.frame.spatial_nndm_cv)
1819
S3method(vec_cast,spatial_block_cv.data.frame)
1920
S3method(vec_cast,spatial_block_cv.spatial_block_cv)
2021
S3method(vec_cast,spatial_block_cv.tbl_df)
@@ -27,14 +28,19 @@ S3method(vec_cast,spatial_clustering_cv.tbl_df)
2728
S3method(vec_cast,spatial_leave_location_out_cv.data.frame)
2829
S3method(vec_cast,spatial_leave_location_out_cv.spatial_leave_location_out_cv)
2930
S3method(vec_cast,spatial_leave_location_out_cv.tbl_df)
31+
S3method(vec_cast,spatial_nndm_cv.data.frame)
32+
S3method(vec_cast,spatial_nndm_cv.spatial_nndm_cv)
33+
S3method(vec_cast,spatial_nndm_cv.tbl_df)
3034
S3method(vec_cast,tbl_df.spatial_block_cv)
3135
S3method(vec_cast,tbl_df.spatial_buffer_vfold_cv)
3236
S3method(vec_cast,tbl_df.spatial_clustering_cv)
3337
S3method(vec_cast,tbl_df.spatial_leave_location_out_cv)
38+
S3method(vec_cast,tbl_df.spatial_nndm_cv)
3439
S3method(vec_ptype2,data.frame.spatial_block_cv)
3540
S3method(vec_ptype2,data.frame.spatial_buffer_vfold_cv)
3641
S3method(vec_ptype2,data.frame.spatial_clustering_cv)
3742
S3method(vec_ptype2,data.frame.spatial_leave_location_out_cv)
43+
S3method(vec_ptype2,data.frame.spatial_nndm_cv)
3844
S3method(vec_ptype2,spatial_block_cv.data.frame)
3945
S3method(vec_ptype2,spatial_block_cv.spatial_block_cv)
4046
S3method(vec_ptype2,spatial_block_cv.tbl_df)
@@ -47,21 +53,27 @@ S3method(vec_ptype2,spatial_clustering_cv.tbl_df)
4753
S3method(vec_ptype2,spatial_leave_location_out_cv.data.frame)
4854
S3method(vec_ptype2,spatial_leave_location_out_cv.spatial_leave_location_out_cv)
4955
S3method(vec_ptype2,spatial_leave_location_out_cv.tbl_df)
56+
S3method(vec_ptype2,spatial_nndm_cv.data.frame)
57+
S3method(vec_ptype2,spatial_nndm_cv.spatial_nndm_cv)
58+
S3method(vec_ptype2,spatial_nndm_cv.tbl_df)
5059
S3method(vec_ptype2,tbl_df.spatial_block_cv)
5160
S3method(vec_ptype2,tbl_df.spatial_buffer_vfold_cv)
5261
S3method(vec_ptype2,tbl_df.spatial_clustering_cv)
5362
S3method(vec_ptype2,tbl_df.spatial_leave_location_out_cv)
63+
S3method(vec_ptype2,tbl_df.spatial_nndm_cv)
5464
S3method(vec_restore,spatial_block_cv)
5565
S3method(vec_restore,spatial_buffer_vfold_cv)
5666
S3method(vec_restore,spatial_clustering_cv)
5767
S3method(vec_restore,spatial_leave_location_out_cv)
68+
S3method(vec_restore,spatial_nndm_cv)
5869
export(analysis)
5970
export(assessment)
6071
export(autoplot)
6172
export(spatial_block_cv)
6273
export(spatial_buffer_vfold_cv)
6374
export(spatial_clustering_cv)
6475
export(spatial_leave_location_out_cv)
76+
export(spatial_nndm_cv)
6577
import(sf)
6678
import(vctrs)
6779
importFrom(dplyr,dplyr_reconstruct)

NEWS.md

+5
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,10 @@
11
# spatialsample (development version)
22

3+
* `spatial_nndm_cv()` is a new function for nearest neighbor distance matching
4+
cross-validation, as described in Milà et al. 2022
5+
(doi: 10.1111/2041-210X.13851). NNDM was first implemented in CAST
6+
(https://cran.r-project.org/package=CAST).
7+
38
# spatialsample 0.3.0
49

510
## Breaking changes

R/compat-vctrs-helpers.R

+2-1
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@ delayedAssign("rset_subclasses", {
1717
spatial_block_cv = spatial_block_cv(test_data()),
1818
spatial_clustering_cv = spatial_clustering_cv(test_data()),
1919
spatial_buffer_vfold_cv = spatial_buffer_vfold_cv(test_data(), radius = 1, buffer = 1),
20-
spatial_leave_location_out_cv = spatial_leave_location_out_cv(test_data(), idx)
20+
spatial_leave_location_out_cv = spatial_leave_location_out_cv(test_data(), idx),
21+
spatial_nndm_cv = spatial_nndm_cv(test_data()[1:500, ], test_data()[501:682, ])
2122
)
2223
)
2324
} else {

R/spatial_nndm_cv.R

+269
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,269 @@
1+
#' Nearest neighbor distance matching (NNDM) cross-validation
#'
#' NNDM is a variant of leave-one-out cross-validation which assigns each
#' observation to a single assessment fold, and then attempts to remove data
#' from each analysis fold until the nearest neighbor distance distribution
#' between assessment and analysis folds matches the nearest neighbor distance
#' distribution between training data and the locations a model will be used to
#' predict.
#' Proposed by Milà et al. (2022), this method aims to provide accurate
#' estimates of how well models will perform in the locations they will actually
#' be predicting. This method was originally implemented in the CAST package.
#'
#' Note that, as a form of leave-one-out cross-validation, this method can be
#' rather slow for larger data (and fitting models to these resamples will be
#' even slower).
#'
#' @param data An object of class `sf` or `sfc`.
#' @param prediction_sites An `sf` or `sfc` object describing the areas to be
#' predicted. If `prediction_sites` are all points, then those points are
#' treated as the intended prediction points when calculating target nearest
#' neighbor distances. If any element of `prediction_sites` is not a single
#' point, then points are sampled from within the bounding box of
#' `prediction_sites` and those points are then used as the intended prediction
#' points.
#' @param ... Additional arguments passed to [sf::st_sample()]. Note that the
#' number of points to sample is controlled by `prediction_sample_size`; trying
#' to pass `size` via `...` will cause an error.
#' @param autocorrelation_range A numeric of length 1 representing the landscape
#' autocorrelation range ("phi" in the terminology of Milà et al. (2022)). If
#' `NULL`, the default, the autocorrelation range is assumed to be the distance
#' between the opposite corners of the bounding box of `prediction_sites`.
#' @param prediction_sample_size A numeric of length 1: the number of points to
#' sample when `prediction_sites` is not only composed of points. Note that this
#' argument is passed to `size` in [sf::st_sample()], meaning that no elements
#' of `...` can be named `size`.
#' @param min_analysis_proportion The minimum proportion of `data` that must
#' remain after removing points to match nearest neighbor distances. This
#' function will stop removing data from analysis sets once only
#' `min_analysis_proportion` of the original data remains in analysis sets, even
#' if the nearest neighbor distances between analysis and assessment sets are
#' still lower than those between training and prediction locations.
#'
#' @return A tibble with classes `spatial_nndm_cv`, `spatial_rset`, `rset`,
#' `tbl_df`, `tbl`, and `data.frame`. The results include a column for the
#' data split objects and an identification variable `id`.
#'
#' @references
#' C. Milà, J. Mateu, E. Pebesma, and H. Meyer. 2022. "Nearest Neighbour
#' Distance Matching Leave-One-Out Cross-Validation for map validation." Methods
#' in Ecology and Evolution 2022:13, pp 1304–1316.
#' doi: 10.1111/2041-210X.13851.
#'
#' H. Meyer and E. Pebesma. 2022. "Machine learning-based global maps of
#' ecological variables and the challenge of assessing them."
#' Nature Communications 13, pp 2208. doi: 10.1038/s41467-022-29838-9.
#'
#' @examplesIf rlang::is_installed("modeldata")
#' data(ames, package = "modeldata")
#' ames_sf <- sf::st_as_sf(ames, coords = c("Longitude", "Latitude"), crs = 4326)
#'
#' # Using a small subset of the data, to make the example run faster:
#' spatial_nndm_cv(ames_sf[1:200, ], ames_sf[2001:2200, ])
#'
#' @export
spatial_nndm_cv <- function(data, prediction_sites, ...,
                            autocorrelation_range = NULL,
                            prediction_sample_size = 1000,
                            min_analysis_proportion = 0.5) {
  # Data validation: check that all dots are used,
  # that data and prediction_sites are sf objects,
  # that data has a CRS and s2 is enabled if necessary
  rlang::check_dots_used()

  standard_checks(data, "`spatial_nndm_cv()`", rlang::current_env())
  if (!is_sf(prediction_sites)) {
    rlang::abort(
      c(
        glue::glue("`spatial_nndm_cv()` currently only supports `sf` objects."),
        i = "Try converting `prediction_sites` to an `sf` object via `sf::st_as_sf()`."
      )
    )
  }

  # sf::st_distance won't reproject automatically, so if prediction_sites
  # isn't already aligned with data, reproject coordinates to prevent
  # distance calculations from failing
  if (!isTRUE(sf::st_crs(prediction_sites) == sf::st_crs(data))) {
    rlang::warn(
      c(
        "Reprojecting `prediction_sites` to match the CRS of `data`.",
        i = "Reproject `prediction_sites` and `data` to share a CRS to avoid this warning."
      )
    )
    if (is.na(sf::st_crs(prediction_sites))) {
      # No CRS to transform from: adopt the CRS of `data` as-is
      prediction_sites <- sf::st_set_crs(prediction_sites, sf::st_crs(data))
    } else {
      prediction_sites <- sf::st_transform(prediction_sites, sf::st_crs(data))
    }
  }

  # Attributes that will be attached to the rset object
  # Importantly this is before we sample prediction_sites
  # or compute autocorrelation_range,
  # primarily for compatibility with rsample::reshuffle_rset()
  cv_att <- list(
    prediction_sites = prediction_sites,
    prediction_sample_size = prediction_sample_size,
    autocorrelation_range = autocorrelation_range,
    min_analysis_proportion = min_analysis_proportion,
    ...
  )

  ######## Actual processing begins here ########
  # "If any element of `prediction_sites` is not a single point,
  # then points are sampled from within the bounding box of `prediction_sites`"
  # An sf object can contain multiple geometry types, so the comparison may
  # have any length (including zero for an empty input); `any()` collapses it
  # to the single TRUE/FALSE that `if` requires
  pred_geometry <- unique(sf::st_geometry_type(prediction_sites))
  if (any(pred_geometry != "POINT")) {
    prediction_sites <- sf::st_sample(
      x = sf::st_as_sfc(sf::st_bbox(prediction_sites)),
      size = prediction_sample_size,
      ...
    )
  }

  # Set autocorrelation range, if not specified, to be the distance between
  # the bottom-left and upper-right corners of prediction_sites --
  # the idea being that this is the maximum relevant distance for
  # autocorrelation, and there's limited harm in assuming too long a range
  # (at least, versus too short)
  #
  # We do this after sampling for 1:1 compatibility with CAST
  if (is.null(autocorrelation_range)) {
    bbox <- sf::st_bbox(prediction_sites)

    # st_distance() on the two corner points returns a 2x2 matrix;
    # element [2] is the corner-to-corner distance
    autocorrelation_range <- sf::st_distance(
      sf::st_as_sf(
        data.frame(
          lon = bbox[c("xmin", "xmax")],
          lat = bbox[c("ymin", "ymax")]
        ),
        coords = c("lon", "lat"),
        crs = sf::st_crs(prediction_sites)
      )
    )[2]
  }

  # Distance from each prediction point to its nearest training point
  dist_to_nn_prediction <- apply(
    sf::st_distance(prediction_sites, data),
    1,
    min
  )

  distance_matrix <- sf::st_distance(data)

  # We've enforced that prediction_sites and data are in the same CRS;
  # therefore nearest_neighbors and distance_matrix are in the same units
  # Force autocorrelation_range into the same units:
  units(autocorrelation_range) <- units(distance_matrix)

  # We're guaranteed to be working in one set of units now,
  # which means we should be able to drop units entirely at this point
  # (which should make some of the logic here easier)
  units(autocorrelation_range) <- NULL
  units(distance_matrix) <- NULL

  # A point is not its own neighbor. From here on, NA in row i, column j
  # means "observation j is excluded from the analysis set of fold i"
  diag(distance_matrix) <- NA
  dist_to_nn_training <- apply(distance_matrix, 1, min, na.rm = TRUE)

  current_neighbor <- list(
    distance = min(dist_to_nn_training),
    row = which.min(dist_to_nn_training)[1]
  )
  current_neighbor$col <- which.min(distance_matrix[current_neighbor$row, ])

  n_training <- nrow(data)

  # Core loop: try to match the empirical nearest neighbor distribution curves
  # (adjusting the training:training curve to that of prediction:training)
  while (current_neighbor$distance <= autocorrelation_range) {
    # Proportion of training data with a neighbor in training
    # closer than current_neighbor$distance if we removed one additional point
    # (hence 1 / n_training)
    prop_close_training <-
      mean(dist_to_nn_training <= current_neighbor$distance) - (1 / n_training)
    # Proportion of prediction data with a neighbor in training data
    # closer than current_neighbor$distance
    prop_close_prediction <- mean(
      dist_to_nn_prediction <= current_neighbor$distance
    )

    # How much data remains in analysis sets?
    prop_remaining <- sum(
      !is.na(distance_matrix[current_neighbor$row, ])
    ) / n_training

    # Both operands are length-one, so use the scalar, short-circuiting `&&`
    if ((prop_close_training >= prop_close_prediction) &&
      (prop_remaining > min_analysis_proportion)) {
      # Remove nearest neighbors from analysis sets until the % of points with
      # an NN in analysis at distance X in analysis ~= the % of points
      # in predict with NN in train at distance X
      distance_matrix[current_neighbor$row, current_neighbor$col] <- NA

      # Only this row of the matrix changed, so only its nearest-neighbor
      # distance needs recomputing (the other rows' minima are unaffected)
      dist_to_nn_training[current_neighbor$row] <- min(
        distance_matrix[current_neighbor$row, ],
        na.rm = TRUE
      )

      # Then update "distance X" to be the next nearest neighbor
      #
      # We just set the distance at current_neighbor to NA,
      # so using >= won't just select the same neighbor over and over again
      current_neighbor <- find_next_neighbor(
        current_neighbor,
        dist_to_nn_training,
        distance_matrix,
        equal_distance_ok = TRUE
      )
    } else {
      # If prop_close_training < prop_close_prediction,
      # we don't need to remove the current point;
      # as such, we need to find a distance >, rather than >=,
      # to the current neighbor
      # (or else we'd loop on this point forever)
      current_neighbor <- find_next_neighbor(
        current_neighbor,
        dist_to_nn_training,
        distance_matrix,
        equal_distance_ok = FALSE
      )
    }

    # Stop once no observation has a nearest neighbor farther away than the
    # current distance: there is nothing left to step to
    if (!any(dist_to_nn_training > current_neighbor$distance)) {
      break
    }
  }

  # One fold per observation: assessment is the observation itself,
  # analysis is every observation still present (non-NA) in its row
  indices <- purrr::map(
    seq_len(nrow(distance_matrix)),
    function(i) {
      list(
        analysis = which(!is.na(distance_matrix[i, ])),
        assessment = i
      )
    }
  )

  split_objs <- purrr::map(
    indices,
    make_splits,
    data = data,
    class = c("spatial_nndm_split", "spatial_rsplit")
  )

  new_rset(
    splits = split_objs,
    ids = names0(length(split_objs), "Fold"),
    attrib = cv_att,
    subclass = c("spatial_nndm_cv", "spatial_rset", "rset")
  )
}
262+
263+
# Advance the NNDM search to the next candidate nearest-neighbor distance.
#
# current_neighbor: list with elements `distance`, `row` and `col` describing
#   the neighbor pair currently under consideration.
# dist_to_nn_training: numeric vector of nearest-neighbor distances, one per
#   row of `distance_matrix`.
# distance_matrix: numeric matrix of pairwise distances; entries may be NA.
# equal_distance_ok: if TRUE, the next distance may equal the current one
#   (`>=`); if FALSE it must be strictly greater (`>`).
#
# Returns `current_neighbor` with `distance` set to the smallest qualifying
# value of `dist_to_nn_training`, `row` set to the first row having that
# nearest-neighbor distance, and `col` set to every column of that row holding
# that distance (so `col` can have length > 1 when distances tie).
# When no value qualifies, `distance` is `Inf`; `row` is then NA and `col` is
# empty unless `dist_to_nn_training` itself contains `Inf`.
find_next_neighbor <- function(current_neighbor,
                               dist_to_nn_training,
                               distance_matrix,
                               equal_distance_ok = FALSE) {
  operator <- if (equal_distance_ok) `>=` else `>`

  candidates <- dist_to_nn_training[
    operator(dist_to_nn_training, current_neighbor$distance)
  ]

  # min() of a zero-length vector returns Inf but also emits a warning;
  # handle the exhausted case explicitly to get the same value silently
  current_neighbor$distance <- if (length(candidates) > 0) {
    min(candidates)
  } else {
    Inf
  }

  current_neighbor$row <- which(
    dist_to_nn_training == current_neighbor$distance
  )[1]
  current_neighbor$col <- which(
    distance_matrix[current_neighbor$row, ] == current_neighbor$distance
  )
  current_neighbor
}

0 commit comments

Comments
 (0)