Skip to content

Commit 5c7cd9c

Browse files
author
Taher Chegini
committed
MNT: Fix issues raised by pyright. [skip ci]
1 parent f37268a commit 5c7cd9c

File tree

2 files changed

+63
-130
lines changed

2 files changed

+63
-130
lines changed

src/pygridmet/cli.py

+26-4
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from __future__ import annotations
44

55
from pathlib import Path
6-
from typing import TYPE_CHECKING, TypeVar
6+
from typing import TYPE_CHECKING, Literal, TypeVar
77

88
import click
99
import geopandas as gpd
@@ -19,6 +19,24 @@
1919

2020
if TYPE_CHECKING:
2121
DFType = TypeVar("DFType", pd.DataFrame, gpd.GeoDataFrame)
22+
VARS = Literal[
23+
"pr",
24+
"rmax",
25+
"rmin",
26+
"sph",
27+
"srad",
28+
"th",
29+
"tmmn",
30+
"tmmx",
31+
"vs",
32+
"bi",
33+
"fm100",
34+
"fm1000",
35+
"erc",
36+
"etr",
37+
"pet",
38+
"vpd",
39+
]
2240

2341

2442
def parse_snow(target_df: pd.DataFrame) -> pd.DataFrame:
@@ -90,7 +108,7 @@ def cli() -> None:
90108
@ssl_opt
91109
def coords(
92110
fpath: Path,
93-
variables: list[str] | str | None = None,
111+
variables: list[VARS] | VARS | None = None,
94112
save_dir: str | Path = "clm_gridmet",
95113
disable_ssl: bool = False,
96114
) -> None:
@@ -139,7 +157,11 @@ def coords(
139157
if fname.exists():
140158
continue
141159
kwrgs = dict(zip(req_cols[1:], args))
142-
clm = gridmet.get_bycoords(**kwrgs, variables=variables, ssl=not disable_ssl)
160+
clm = gridmet.get_bycoords(
161+
**kwrgs,
162+
variables=variables,
163+
ssl=not disable_ssl,
164+
)
143165
clm.to_csv(fname, index=False)
144166
click.echo("Done.")
145167

@@ -151,7 +173,7 @@ def coords(
151173
@ssl_opt
152174
def geometry(
153175
fpath: Path,
154-
variables: list[str] | str | None = None,
176+
variables: list[VARS] | VARS | None = None,
155177
save_dir: str | Path = "clm_gridmet",
156178
disable_ssl: bool = False,
157179
) -> None:

src/pygridmet/pygridmet.py

+37-126
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,24 @@
2828
from shapely import MultiPolygon, Polygon
2929

3030
CRSTYPE = Union[int, str, pyproj.CRS]
31+
VARS = Literal[
32+
"pr",
33+
"rmax",
34+
"rmin",
35+
"sph",
36+
"srad",
37+
"th",
38+
"tmmn",
39+
"tmmx",
40+
"vs",
41+
"bi",
42+
"fm100",
43+
"fm1000",
44+
"erc",
45+
"etr",
46+
"pet",
47+
"vpd",
48+
]
3149

3250
DATE_FMT = "%Y-%m-%dT%H:%M:%SZ"
3351
MAX_CONN = 4
@@ -38,26 +56,7 @@
3856

3957
def _coord_urls(
4058
coord: tuple[float, float],
41-
variables: Iterable[
42-
Literal[
43-
"pr",
44-
"rmax",
45-
"rmin",
46-
"sph",
47-
"srad",
48-
"th",
49-
"tmmn",
50-
"tmmx",
51-
"vs",
52-
"bi",
53-
"fm100",
54-
"fm1000",
55-
"erc",
56-
"etr",
57-
"pet",
58-
"vpd",
59-
]
60-
],
59+
variables: Iterable[VARS],
6160
dates: list[tuple[pd.Timestamp, pd.Timestamp]],
6261
long_names: dict[str, str],
6362
) -> Generator[list[tuple[str, dict[str, dict[str, str]]]], None, None]:
@@ -129,7 +128,7 @@ def _by_coord(
129128
) -> pd.DataFrame:
130129
"""Get climate data for a coordinate and return as a DataFrame."""
131130
coords = (lon, lat)
132-
url_kwds = _coord_urls(coords, gridmet.variables, dates, gridmet.long_names)
131+
url_kwds = _coord_urls(coords, gridmet.variables, dates, gridmet.long_names) # pyright: ignore[reportArgumentType]
133132
retrieve = functools.partial(ar.retrieve_text, max_workers=MAX_CONN, ssl=ssl)
134133

135134
clm = pd.concat( # pyright: ignore[reportCallIssue]
@@ -167,45 +166,7 @@ def get_bycoords(
167166
dates: tuple[str, str] | int | list[int],
168167
coords_id: Sequence[str | int] | None = None,
169168
crs: CRSTYPE = 4326,
170-
variables: Iterable[
171-
Literal[
172-
"pr",
173-
"rmax",
174-
"rmin",
175-
"sph",
176-
"srad",
177-
"th",
178-
"tmmn",
179-
"tmmx",
180-
"vs",
181-
"bi",
182-
"fm100",
183-
"fm1000",
184-
"erc",
185-
"etr",
186-
"pet",
187-
"vpd",
188-
]
189-
]
190-
| Literal[
191-
"pr",
192-
"rmax",
193-
"rmin",
194-
"sph",
195-
"srad",
196-
"th",
197-
"tmmn",
198-
"tmmx",
199-
"vs",
200-
"bi",
201-
"fm100",
202-
"fm1000",
203-
"erc",
204-
"etr",
205-
"pet",
206-
"vpd",
207-
]
208-
| None = None,
169+
variables: Iterable[VARS] | VARS | None = None,
209170
snow: bool = False,
210171
snow_params: dict[str, float] | None = None,
211172
ssl: bool = True,
@@ -303,26 +264,7 @@ def get_bycoords(
303264

304265
def _gridded_urls(
305266
bounds: tuple[float, float, float, float],
306-
variables: Iterable[
307-
Literal[
308-
"pr",
309-
"rmax",
310-
"rmin",
311-
"sph",
312-
"srad",
313-
"th",
314-
"tmmn",
315-
"tmmx",
316-
"vs",
317-
"bi",
318-
"fm100",
319-
"fm1000",
320-
"erc",
321-
"etr",
322-
"pet",
323-
"vpd",
324-
]
325-
],
267+
variables: Iterable[VARS],
326268
dates: list[tuple[pd.Timestamp, pd.Timestamp]],
327269
long_names: dict[str, str],
328270
) -> tuple[list[str], list[dict[str, dict[str, str]]]]:
@@ -413,27 +355,34 @@ def _check_nans(
413355
def _download_urls(
414356
urls: list[str],
415357
kwds: list[dict[str, dict[str, str]]],
416-
clm_files: list[Path],
358+
clm_files: Sequence[Path],
417359
ssl: bool,
418360
long2abbr: dict[str, str],
419361
) -> xr.Dataset:
420362
"""Download the URLs and return the dataset."""
421-
clm_files_full = clm_files.copy()
363+
clm_files_full = list(clm_files)
364+
clm_files_ = clm_files_full.copy()
422365
clm = None
423366
# Sometimes the server returns NaNs, so we must check for that, remove
424367
# the files containing NaNs, and try again.
425368
for _ in range(N_RETRIES):
426-
clm_files = ogc.streaming_download(urls, kwds, clm_files, ssl=ssl, n_jobs=MAX_CONN)
427-
clm_files = [f for f in clm_files if f is not None]
369+
clm_files_ = ogc.streaming_download(
370+
urls,
371+
kwds,
372+
clm_files_,
373+
ssl=ssl,
374+
n_jobs=MAX_CONN,
375+
)
376+
clm_files_ = [f for f in clm_files_ if f is not None]
428377
try:
429378
# open_mfdataset can run into too many open files error so we use merge
430379
# https://docs.xarray.dev/en/stable/user-guide/io.html#reading-multi-file-datasets
431380
clm = xr.merge(_open_dataset(f) for f in clm_files_full).astype("f4")
432381
except ValueError:
433-
_ = [f.unlink() for f in clm_files]
382+
_ = [f.unlink() for f in clm_files_]
434383
continue
435384

436-
has_nans, urls, kwds, clm_files = _check_nans(clm, urls, kwds, clm_files, long2abbr)
385+
has_nans, urls, kwds, clm_files_ = _check_nans(clm, urls, kwds, clm_files_, long2abbr)
437386
if has_nans:
438387
clm = None
439388
continue
@@ -454,45 +403,7 @@ def get_bygeom(
454403
geometry: Polygon | MultiPolygon | tuple[float, float, float, float],
455404
dates: tuple[str, str] | int | list[int],
456405
crs: CRSTYPE = 4326,
457-
variables: Iterable[
458-
Literal[
459-
"pr",
460-
"rmax",
461-
"rmin",
462-
"sph",
463-
"srad",
464-
"th",
465-
"tmmn",
466-
"tmmx",
467-
"vs",
468-
"bi",
469-
"fm100",
470-
"fm1000",
471-
"erc",
472-
"etr",
473-
"pet",
474-
"vpd",
475-
]
476-
]
477-
| Literal[
478-
"pr",
479-
"rmax",
480-
"rmin",
481-
"sph",
482-
"srad",
483-
"th",
484-
"tmmn",
485-
"tmmx",
486-
"vs",
487-
"bi",
488-
"fm100",
489-
"fm1000",
490-
"erc",
491-
"etr",
492-
"pet",
493-
"vpd",
494-
]
495-
| None = None,
406+
variables: Iterable[VARS] | VARS | None = None,
496407
snow: bool = False,
497408
snow_params: dict[str, float] | None = None,
498409
ssl: bool = True,
@@ -551,7 +462,7 @@ def get_bygeom(
551462

552463
urls, kwds = _gridded_urls(
553464
_geometry.bounds, # pyright: ignore[reportGeneralTypeIssues]
554-
gridmet.variables,
465+
gridmet.variables, # pyright: ignore[reportArgumentType]
555466
gridmet.date_iterator,
556467
gridmet.long_names,
557468
)

0 commit comments

Comments
 (0)