|
| 1 | +# Copyright (c) Microsoft Corporation. All rights reserved. |
| 2 | +# Licensed under the MIT License. |
| 3 | + |
| 4 | +"""Northeastern China Crop Map Dataset.""" |
| 5 | + |
| 6 | +import glob |
| 7 | +import os |
| 8 | +from collections.abc import Iterable |
| 9 | +import pathlib |
| 10 | +from typing import Any, Callable, Optional, Union |
| 11 | + |
| 12 | +import matplotlib.pyplot as plt |
| 13 | +import torch |
| 14 | +from matplotlib.figure import Figure |
| 15 | +from rasterio.crs import CRS |
| 16 | + |
| 17 | +from .geo import RasterDataset |
| 18 | +from .utils import BoundingBox, DatasetNotFoundError, download_url, extract_archive |
| 19 | + |
| 20 | + |
class NCCM(RasterDataset):
    """The Northeastern China Crop Map Dataset.

    Link: https://www.nature.com/articles/s41597-021-00827-9

    This dataset produced annual 10-m crop maps of the
    major crops (maize, soybean, and rice)
    in Northeast China from 2017 to 2019, using hierarchical mapping strategies,
    random forest classifiers, interpolated and
    smoothed 10-day Sentinel-2 time series data and
    optimized features from spectral, temporal and
    textural characteristics of the land surface.
    The resultant maps have high overall accuracies (OA)
    based on ground truth data. The dataset contains information
    specific to three years: 2017, 2018, 2019.

    The dataset contains 5 classes:

    0. paddy rice
    1. maize
    2. soybean
    3. other crops and lands
    4. nodata

    Dataset format:

    * Three .TIF files containing the labels
    * JavaScript code to download images from the dataset.

    If you use this dataset in your research, please cite the following paper:

    * https://doi.org/10.1038/s41597-021-00827-9

    .. versionadded:: 0.6
    """

    # Label files are named like ``CDL2017_clip.tif``; the 4-digit year in the
    # filename is the sample timestamp (see ``date_format`` below).
    filename_regex = r"CDL(?P<year>\d{4})_clip"
    filename_glob = "CDL*.*"
    # Single figshare archive containing all three annual label rasters.
    zipfile_glob = "13090442.zip"

    date_format = "%Y"
    is_image = False
    url = "https://figshare.com/ndownloader/articles/13090442/versions/1"
    md5 = "eae952f1b346d7e649d027e8139a76f5"

    # Raw raster value -> RGBA color. The raw files use 15 as the nodata
    # value; ``__init__`` remaps it (like every other key) to a dense ordinal
    # index, so 15 becomes ordinal class 4 ("nodata").
    cmap = {
        0: (0, 255, 0, 255),
        1: (255, 0, 0, 255),
        2: (255, 255, 0, 255),
        3: (128, 128, 128, 255),
        15: (255, 255, 255, 255),
    }

    def __init__(
        self,
        paths: Union[pathlib.Path, str, Iterable[Union[pathlib.Path, str]]] = "data",
        crs: Optional[CRS] = None,
        res: Optional[float] = None,
        transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None,
        cache: bool = True,
        download: bool = False,
        checksum: bool = False,
    ) -> None:
        """Initialize a new dataset.

        Args:
            paths: one or more root directories to search or files to load
            crs: :term:`coordinate reference system (CRS)` to warp to
                (defaults to the CRS of the first file found)
            res: resolution of the dataset in units of CRS
                (defaults to the resolution of the first file found)
            transforms: a function/transform that takes an input sample
                and returns a transformed version
            cache: if True, cache file handle to speed up repeated sampling
            download: if True, download dataset and store it in the root directory
            checksum: if True, check the MD5 after downloading files (may be slow)

        Raises:
            DatasetNotFoundError: If dataset is not found and *download* is False.
        """
        self.paths = paths
        self.download = download
        self.checksum = checksum
        # Lookup table mapping raw raster values to dense ordinal classes.
        # Unlisted values default to 4 ("nodata"); listed keys are filled in
        # after ``super().__init__`` below.
        self.ordinal_map = torch.full((max(self.cmap.keys()) + 1,), 4, dtype=self.dtype)
        # One RGBA row per ordinal class, used by :meth:`plot`.
        self.ordinal_cmap = torch.zeros((5, 4), dtype=torch.uint8)

        self._verify()
        super().__init__(paths, crs, res, transforms=transforms, cache=cache)

        # Enumerate cmap keys in insertion order so raw value k maps to
        # ordinal index i, and ordinal index i maps to color cmap[k].
        for i, (k, v) in enumerate(self.cmap.items()):
            self.ordinal_map[k] = i
            self.ordinal_cmap[i] = torch.tensor(v)

    def __getitem__(self, query: BoundingBox) -> dict[str, Any]:
        """Retrieve mask and metadata indexed by query.

        Args:
            query: (minx, maxx, miny, maxy, mint, maxt) coordinates to index

        Returns:
            sample of mask and metadata at that index

        Raises:
            IndexError: if query is not found in the index
        """
        sample = super().__getitem__(query)
        # Convert raw raster values (0-3, 15) to dense ordinal classes (0-4).
        sample["mask"] = self.ordinal_map[sample["mask"]]
        return sample

    def _verify(self) -> None:
        """Verify the integrity of the dataset.

        Raises:
            DatasetNotFoundError: If dataset is not found and *download* is False.
        """
        # Check if the extracted files already exist
        if self.files:
            return

        # Check if the zip file has already been downloaded
        assert isinstance(self.paths, (pathlib.Path, str))
        pathname = os.path.join(self.paths, "**", self.zipfile_glob)
        if glob.glob(pathname, recursive=True):
            self._extract()
            return

        # Check if the user requested to download the dataset
        if not self.download:
            raise DatasetNotFoundError(self)

        # Download the dataset
        self._download()
        self._extract()

    def _download(self) -> None:
        """Download the dataset."""
        # Reuse the class attribute instead of repeating the literal filename,
        # so the download target and the glob used by _verify/_extract agree.
        download_url(
            self.url, self.paths, self.zipfile_glob, md5=self.md5 if self.checksum else None
        )

    def _extract(self) -> None:
        """Extract the dataset."""
        assert isinstance(self.paths, (pathlib.Path, str))
        pathname = os.path.join(self.paths, "**", self.zipfile_glob)
        extract_archive(glob.glob(pathname, recursive=True)[0], self.paths)

    def plot(
        self,
        sample: dict[str, Any],
        show_titles: bool = True,
        suptitle: Optional[str] = None,
    ) -> Figure:
        """Plot a sample from the dataset.

        Args:
            sample: a sample returned by :meth:`NCCM.__getitem__`
            show_titles: flag indicating whether to show titles above each panel
            suptitle: optional string to use as a suptitle

        Returns:
            a matplotlib Figure with the rendered sample
        """
        mask = sample["mask"].squeeze()
        ncols = 1

        showing_predictions = "prediction" in sample
        if showing_predictions:
            pred = sample["prediction"].squeeze()
            ncols = 2

        fig, axs = plt.subplots(
            nrows=1, ncols=ncols, figsize=(ncols * 4, 4), squeeze=False
        )

        # Index the RGBA palette with the ordinal mask to build an RGB image.
        axs[0, 0].imshow(self.ordinal_cmap[mask], interpolation="none")
        axs[0, 0].axis("off")

        if show_titles:
            axs[0, 0].set_title("Mask")

        if showing_predictions:
            axs[0, 1].imshow(self.ordinal_cmap[pred], interpolation="none")
            axs[0, 1].axis("off")
            if show_titles:
                axs[0, 1].set_title("Prediction")

        if suptitle is not None:
            plt.suptitle(suptitle)

        return fig
0 commit comments