
Commit f3a77fc

Author: Hitesh Tolani
Commit message: Bug fixes.
1 parent 926fdf4 commit f3a77fc

File tree: 3 files changed, +218 -6 lines


torchgeo/datasets/__init__.py (+5)

@@ -77,6 +77,7 @@
 from .millionaid import MillionAID
 from .naip import NAIP
 from .nasa_marine_debris import NASAMarineDebris
+from .nccm import NCCM
 from .nlcd import NLCD
 from .openbuildings import OpenBuildings
 from .oscd import OSCD
@@ -116,6 +117,7 @@
 from .usavars import USAVars
 from .utils import (
     BoundingBox,
+    DatasetNotFoundError,
     concat_samples,
     merge_samples,
     stack_samples,
@@ -167,6 +169,7 @@
     "Landsat8",
     "Landsat9",
     "NAIP",
+    "NCCM",
     "NLCD",
     "OpenBuildings",
     "Sentinel",
@@ -253,4 +256,6 @@
     "random_grid_cell_assignment",
     "roi_split",
     "time_series_split",
+    # Errors
+    "DatasetNotFoundError",
 )
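At the package level this commit only re-exports the new names. A minimal sketch of what that enables, assuming a torchgeo build that contains this commit; the data/nccm root is a hypothetical placeholder, not a path from the change:

from torchgeo.datasets import NCCM, DatasetNotFoundError

try:
    # "data/nccm" is an illustrative root directory.
    ds = NCCM(paths="data/nccm", download=False)
except DatasetNotFoundError as err:
    # Raised by NCCM._verify() when no files are found and download is False.
    print(err)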

torchgeo/datasets/nccm.py (new file, +207)

# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""Northeastern China Crop Map Dataset."""

import glob
import os
import pathlib
from collections.abc import Iterable
from typing import Any, Callable, Optional, Union

import matplotlib.pyplot as plt
import torch
from matplotlib.figure import Figure
from rasterio.crs import CRS

from .geo import RasterDataset
from .utils import BoundingBox, DatasetNotFoundError, download_url, extract_archive


class NCCM(RasterDataset):
    """The Northeastern China Crop Map Dataset.

    Link: https://www.nature.com/articles/s41597-021-00827-9

    This dataset provides annual 10-m crop maps of the major crops (maize,
    soybean, and rice) in Northeast China from 2017 to 2019, produced using
    hierarchical mapping strategies, random forest classifiers, interpolated
    and smoothed 10-day Sentinel-2 time series data, and optimized features
    from the spectral, temporal, and textural characteristics of the land
    surface. The resulting maps have high overall accuracies (OA) based on
    ground truth data. The dataset covers three years: 2017, 2018, and 2019.

    The dataset contains 5 classes:

    0. paddy rice
    1. maize
    2. soybean
    3. other crops and lands
    4. nodata

    Dataset format:

    * Three .TIF files containing the labels
    * JavaScript code to download images from the dataset

    If you use this dataset in your research, please cite the following paper:

    * https://doi.org/10.1038/s41597-021-00827-9

    .. versionadded:: 0.6
    """

    filename_regex = r"CDL(?P<year>\d{4})_clip"
    filename_glob = "CDL*.*"
    zipfile_glob = "13090442.zip"

    date_format = "%Y"
    is_image = False
    url = "https://figshare.com/ndownloader/articles/13090442/versions/1"
    md5 = "eae952f1b346d7e649d027e8139a76f5"

    cmap = {
        0: (0, 255, 0, 255),
        1: (255, 0, 0, 255),
        2: (255, 255, 0, 255),
        3: (128, 128, 128, 255),
        15: (255, 255, 255, 255),
    }

    def __init__(
        self,
        paths: Union[pathlib.Path, str, Iterable[Union[pathlib.Path, str]]] = "data",
        crs: Optional[CRS] = None,
        res: Optional[float] = None,
        transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None,
        cache: bool = True,
        download: bool = False,
        checksum: bool = False,
    ) -> None:
        """Initialize a new dataset.

        Args:
            paths: one or more root directories to search or files to load
            crs: :term:`coordinate reference system (CRS)` to warp to
                (defaults to the CRS of the first file found)
            res: resolution of the dataset in units of CRS
                (defaults to the resolution of the first file found)
            transforms: a function/transform that takes an input sample
                and returns a transformed version
            cache: if True, cache file handle to speed up repeated sampling
            download: if True, download dataset and store it in the root directory
            checksum: if True, check the MD5 after downloading files (may be slow)

        Raises:
            DatasetNotFoundError: If dataset is not found and *download* is False.
        """
        self.paths = paths
        self.download = download
        self.checksum = checksum
        # Lookup tables: raw class values (0-3, 15) -> ordinal indices 0-4,
        # and ordinal indices -> RGBA colors for plotting.
        self.ordinal_map = torch.full((max(self.cmap.keys()) + 1,), 4, dtype=self.dtype)
        self.ordinal_cmap = torch.zeros((5, 4), dtype=torch.uint8)

        self._verify()
        super().__init__(paths, crs, res, transforms=transforms, cache=cache)

        for i, (k, v) in enumerate(self.cmap.items()):
            self.ordinal_map[k] = i
            self.ordinal_cmap[i] = torch.tensor(v)

    def __getitem__(self, query: BoundingBox) -> dict[str, Any]:
        """Retrieve mask and metadata indexed by query.

        Args:
            query: (minx, maxx, miny, maxy, mint, maxt) coordinates to index

        Returns:
            sample of mask and metadata at that index

        Raises:
            IndexError: if query is not found in the index
        """
        sample = super().__getitem__(query)
        sample["mask"] = self.ordinal_map[sample["mask"]]
        return sample

    def _verify(self) -> None:
        """Verify the integrity of the dataset."""
        # Check if the extracted files already exist
        if self.files:
            return

        # Check if the zip file has already been downloaded
        assert isinstance(self.paths, (pathlib.Path, str))
        pathname = os.path.join(self.paths, "**", self.zipfile_glob)
        if glob.glob(pathname, recursive=True):
            self._extract()
            return

        # Check if the user requested to download the dataset
        if not self.download:
            raise DatasetNotFoundError(self)

        # Download the dataset
        self._download()
        self._extract()

    def _download(self) -> None:
        """Download the dataset."""
        filename = "13090442.zip"
        download_url(
            self.url, self.paths, filename, md5=self.md5 if self.checksum else None
        )

    def _extract(self) -> None:
        """Extract the dataset."""
        assert isinstance(self.paths, (pathlib.Path, str))
        pathname = os.path.join(self.paths, "**", self.zipfile_glob)
        extract_archive(glob.glob(pathname, recursive=True)[0], self.paths)

    def plot(
        self,
        sample: dict[str, Any],
        show_titles: bool = True,
        suptitle: Optional[str] = None,
    ) -> Figure:
        """Plot a sample from the dataset.

        Args:
            sample: a sample returned by :meth:`NCCM.__getitem__`
            show_titles: flag indicating whether to show titles above each panel
            suptitle: optional string to use as a suptitle

        Returns:
            a matplotlib Figure with the rendered sample
        """
        mask = sample["mask"].squeeze()
        ncols = 1

        showing_predictions = "prediction" in sample
        if showing_predictions:
            pred = sample["prediction"].squeeze()
            ncols = 2

        fig, axs = plt.subplots(
            nrows=1, ncols=ncols, figsize=(ncols * 4, 4), squeeze=False
        )

        axs[0, 0].imshow(self.ordinal_cmap[mask], interpolation="none")
        axs[0, 0].axis("off")

        if show_titles:
            axs[0, 0].set_title("Mask")

        if showing_predictions:
            axs[0, 1].imshow(self.ordinal_cmap[pred], interpolation="none")
            axs[0, 1].axis("off")
            if show_titles:
                axs[0, 1].set_title("Prediction")

        if suptitle is not None:
            plt.suptitle(suptitle)

        return fig
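For context, a hedged sketch of how the new dataset is typically driven through torchgeo's geospatial sampling machinery; the root path, patch size, and sample count below are illustrative assumptions, not values from this commit:

from torch.utils.data import DataLoader
from torchgeo.datasets import NCCM, stack_samples
from torchgeo.samplers import RandomGeoSampler

# download=True fetches and extracts the figshare archive if it is not already present.
ds = NCCM(paths="data/nccm", download=True, checksum=True)

# Randomly sample a handful of 256x256-pixel patches from the raster index.
sampler = RandomGeoSampler(ds, size=256, length=8)
loader = DataLoader(ds, sampler=sampler, collate_fn=stack_samples)

for batch in loader:
    mask = batch["mask"]  # ordinal class indices 0-4 after the remapping in __getitem__
    print(mask.shape, mask.unique())

A single sample can also be rendered with ds.plot(sample, suptitle="NCCM"), which uses the ordinal_cmap lookup defined above.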

torchgeo/datasets/spacenet.py (+6 -6)

@@ -403,7 +403,7 @@ class SpaceNet1(SpaceNet):
 
     def __init__(
         self,
-        root: str = "data",
+        root: Union[pathlib.Path, str] = "data",
         image: str = "rgb",
         transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None,
         download: bool = False,
@@ -518,7 +518,7 @@ class SpaceNet2(SpaceNet):
 
     def __init__(
         self,
-        root: str = "data",
+        root: Union[pathlib.Path, str] = "data",
         image: str = "PS-RGB",
         collections: list[str] = [],
         transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None,
@@ -638,7 +638,7 @@ class SpaceNet3(SpaceNet):
 
     def __init__(
         self,
-        root: str = "data",
+        root: Union[pathlib.Path, str] = "data",
         image: str = "PS-RGB",
         speed_mask: Optional[bool] = False,
         collections: list[str] = [],
@@ -888,7 +888,7 @@ class SpaceNet4(SpaceNet):
 
     def __init__(
         self,
-        root: str = "data",
+        root: Union[pathlib.Path, str] = "data",
         image: str = "PS-RGBNIR",
         angles: list[str] = [],
         transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None,
@@ -1188,7 +1188,7 @@ class SpaceNet6(SpaceNet):
 
     def __init__(
         self,
-        root: str = "data",
+        root: Union[pathlib.Path, str] = "data",
         image: str = "PS-RGB",
         transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None,
         download: bool = False,
@@ -1289,7 +1289,7 @@ class SpaceNet7(SpaceNet):
 
     def __init__(
         self,
-        root: str = "data",
+        root: Union[pathlib.Path, str] = "data",
         split: str = "train",
         transforms: Optional[Callable[[dict[str, Any]], dict[str, Any]]] = None,
         download: bool = False,
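The SpaceNet changes are a typing-only relaxation: each constructor's root argument now accepts a pathlib.Path as well as a str. A small sketch of the intent, assuming the SpaceNet 1 imagery is already available under the placeholder root:

import pathlib

from torchgeo.datasets import SpaceNet1

# Both spellings are now accepted by the annotation; runtime behaviour is unchanged.
ds_str = SpaceNet1(root="data/spacenet1", image="rgb")
ds_path = SpaceNet1(root=pathlib.Path("data") / "spacenet1", image="rgb")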
