261 area selector for url adaptor to accept kwargs #262

Merged: 30 commits, Jan 17, 2025
7 changes: 6 additions & 1 deletion cads_adaptors/adaptors/url.py
@@ -48,6 +48,11 @@ def retrieve_list_of_results(self, request: dict[str, Any]) -> list[str]:
paths = url_tools.try_download(urls, context=self.context, **download_kwargs)

if self.area is not None:
paths = area_selector.area_selector_paths(paths, self.area, self.context)
paths = area_selector.area_selector_paths(
paths,
self.area,
self.context,
**self.config.get("post_processing_kwargs", {}),
)

return paths
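
A sketch of how this plumbing could be driven from a dataset's adaptor configuration: everything under post_processing_kwargs is forwarded verbatim, and the particular keys shown are assumptions drawn from the signatures later in this diff, not documented defaults.

# Hypothetical adaptor config (keys assumed from the signatures below);
# whatever sits under "post_processing_kwargs" is forwarded verbatim to
# area_selector_paths via **self.config.get("post_processing_kwargs", {}).
config = {
    "post_processing_kwargs": {
        "out_format": "netcdf",                    # consumed by area_selector_path
        "area_selector_kwargs": {"precision": 2},  # forwarded on to get_dim_slices
        "open_datasets_kwargs": {"decode_times": False, "chunks": -1},
    }
}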
203 changes: 142 additions & 61 deletions cads_adaptors/tools/area_selector.py
@@ -1,45 +1,58 @@
import os
from copy import deepcopy
from typing import Any, Callable, Type

import dask
import numpy as np
import xarray as xr
from earthkit import data
from earthkit.transforms import tools as eka_tools

from cads_adaptors.adaptors import Context
from cads_adaptors.exceptions import InvalidRequest
from cads_adaptors.tools import adaptor_tools, convertors


def area_to_checked_dictionary(area: list[float | int]) -> dict[str, float | int]:
north, east, south, west = area
if north < south:
south, north = north, south
return {"north": north, "east": east, "south": south, "west": west}


def incompatible_area_error(
dim_key: str,
start: float,
end: float,
coord_range: list,
start: float | int,
end: float | int,
coord_range: list[float | int],
context: Context = Context(),
thisError=ValueError,
):
thisException: Type[Exception] = InvalidRequest,
) -> None:
error_message = (
"Your area selection is not yet compatible with this dataset.\n"
"Your area selection is not compatible with this dataset.\n"
f"Range selection for {dim_key}: [{start}, {end}].\n"
f"Coord range from dataset: {coord_range}"
)
context.add_user_visible_error(error_message)
raise thisError(error_message)
raise thisException(error_message)


def points_inside_range(points, point_range, how=any):
def points_inside_range(
points: list[float | int],
point_range: list[float | int],
how: Callable[[list[bool]], bool] = any,
) -> bool:
return how(
[point >= point_range[0] and point <= point_range[1] for point in points]
)
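# Usage sketch (illustrative values): with the default how=any a single
# overlapping point suffices, while how=all requires full containment.
#   points_inside_range([10.0, 50.0], [0.0, 30.0])           -> True
#   points_inside_range([10.0, 50.0], [0.0, 30.0], how=all)  -> False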


def wrap_longitudes(
dim_key: str,
start: float,
end: float,
coord_range: list,
start: float | int,
end: float | int,
coord_range: list[float | int],
context: Context = Context(),
) -> list:
) -> list[slice]:
start_in = deepcopy(start)
end_in = deepcopy(end)

@@ -88,7 +101,7 @@ def get_dim_slices(
context: Context = Context(),
longitude: bool = False,
precision: int = 2,
) -> list:
) -> list[slice]:
da_coord = ds[dim_key]

ascending = bool(da_coord[0] < da_coord[1]) # True = ascending, False = descending
@@ -122,67 +135,70 @@ def get_dim_slices(

# A final check that there is at least an overlap
if not points_inside_range([start, end], coord_range):
incompatible_area_error(
dim_key, start, end, coord_range, context, thisError=NotImplementedError
)
incompatible_area_error(dim_key, start, end, coord_range, context)

return [slice(start, end)]
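
# Illustrative case (values assumed): on a 0..360 longitude grid, a request
# crossing the antimeridian can come back as two slices, e.g.
#   [slice(350.0, 360.0), slice(0.0, 10.0)]
# area_selector below runs one .sel per slice and concatenates the pieces
# along the longitude dimension.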


def area_selector(
infile: str,
ds: xr.Dataset,
area: list[float | int] | dict[str, float | int] = [+90, -180, -90, +180],
context: Context = Context(),
area: list = [-90, -180, -90, +180],
to_xarray_kwargs: dict = dict(),
**kwargs,
):
north, east, south, west = area

# open object as earthkit data object
ek_d = data.from_source("file", infile)
**_kwargs: dict[str, Any],
) -> xr.Dataset:
if isinstance(area, list):
area = area_to_checked_dictionary(area)

ds = ek_d.to_xarray(**to_xarray_kwargs)
# Take a copy as they will be updated herein
kwargs = deepcopy(_kwargs)

spatial_info = eka_tools.get_spatial_info(ds)
spatial_info = eka_tools.get_spatial_info(
ds,
**{k: kwargs.pop(k) for k in ["lat_key", "lon_key"] if k in kwargs},
)
lon_key = spatial_info["lon_key"]
lat_key = spatial_info["lat_key"]

# Handle simple regular case:
if spatial_info["regular"]:
extra_kwargs: dict[str, Any] = {
k: kwargs.pop(k) for k in ["precision"] if k in kwargs
}
# Longitudes could return multiple slices in cases where the area wraps the "other side"
lon_slices = get_dim_slices(
ds,
lon_key,
east,
west,
area["east"],
area["west"],
context,
longitude=True,
**extra_kwargs,
)
# We assume that latitudes won't be wrapped
lat_slice = get_dim_slices(
ds,
lat_key,
south,
north,
context,
ds, lat_key, area["south"], area["north"], context, **extra_kwargs
)[0]

context.logger.debug(f"lat_slice: {lat_slice}\nlon_slices: {lon_slices}")
context.debug(f"lat_slice: {lat_slice}\nlon_slices: {lon_slices}")

sub_selections = []
for lon_slice in lon_slices:
sel_kwargs: dict[str, Any] = {
**kwargs, # Any remaining kwargs are for the sel command
spatial_info["lat_key"]: lat_slice,
spatial_info["lon_key"]: lon_slice,
}
sub_selections.append(
ds.sel(
**{
spatial_info["lat_key"]: lat_slice,
spatial_info["lon_key"]: lon_slice,
}
**sel_kwargs,
)
)
context.logger.debug(f"selections: {sub_selections}")
context.debug(f"selections: {sub_selections}")

ds_area = xr.concat(sub_selections, dim=lon_key)
context.logger.debug(f"ds_area: {ds_area}")
ds_area = xr.concat(
sub_selections, dim=lon_key, data_vars="minimal", coords="minimal"
)
context.debug(f"ds_area: {ds_area}")

# Ensure that there are no length zero dimensions
for dim in [lat_key, lon_key]:
@@ -203,33 +219,98 @@ def area_selector(
raise NotImplementedError("Area selection not available for data projection")


def area_selector_path(
infile: str,
area: list[float | int] | dict[str, float | int],
context: Context = Context(),
out_format: str | None = None,
target_dir: str | None = None,
area_selector_kwargs: dict[str, Any] = {},
open_datasets_kwargs: list[dict[str, Any]] | dict[str, Any] = {},
**kwargs: dict[str, Any],
) -> list[str]:
if isinstance(area, list):
area = area_to_checked_dictionary(area)

# Deduce input format from infile
in_ext = infile.split(".")[-1]
in_format = adaptor_tools.handle_data_format(in_ext)
if out_format is None:
out_format = in_format

# If target_dir not specified, then use the directory of the input file
if target_dir is None:
target_dir = os.path.dirname(infile)

# Set decode_times to False to avoid any unnecessary issues with decoding time coordinates
# Also set some auto-chunking
if isinstance(open_datasets_kwargs, list):
for _open_dataset_kwargs in open_datasets_kwargs:
_open_dataset_kwargs.setdefault("decode_times", False)
_open_dataset_kwargs.setdefault("chunks", -1)
else:
open_datasets_kwargs.setdefault("decode_times", False)
open_datasets_kwargs.setdefault("chunks", -1)

# open_kwargs =
ds_dict = convertors.open_file_as_xarray_dictionary(
infile,
context=context,
**{
**kwargs,
"open_datasets_kwargs": open_datasets_kwargs,
},
)

ds_area_dict = {
".".join(
[fname_tag, "area-subset"]
+ [str(area[a]) for a in ["north", "west", "south", "east"]]
): area_selector(ds, area=area, context=context, **area_selector_kwargs)
for fname_tag, ds in ds_dict.items()
}

# TODO: Consider using the write to file methods in convertors sub-module
out_paths = []
if out_format in ["nc", "netcdf"]:
for fname_tag, ds_area in ds_area_dict.items():
out_path = os.path.join(target_dir, f"{fname_tag}.nc")
for var in ds_area.variables:
ds_area[var].encoding.setdefault("_FillValue", None)
# Need to compute before writing to disk as dask loses too many jobs
ds_area.compute().to_netcdf(out_path)
out_paths.append(out_path)
else:
context.add_user_visible_error(
f"Cannot write area selected data to {out_format}, writing to netcdf."
)
for fname_tag, ds_area in ds_area_dict.items():
out_path = os.path.join(target_dir, f"{fname_tag}.nc")
for var in ds_area.variables:
ds_area[var].encoding.setdefault("_FillValue", None)
ds_area.compute().to_netcdf(out_path)
out_paths.append(out_path)

return out_paths
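
# Naming sketch (illustrative area): for area = {"north": 60, "west": -10,
# "south": 50, "east": 10}, a dataset tagged "download_0" is written to
#   <target_dir>/download_0.area-subset.60.-10.50.10.nc
# following the north, west, south, east ordering used above.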


def area_selector_paths(
paths: list, area: list, context: Context, out_format: str = "netcdf"
):
paths: list[str],
area: list[float | int] | dict[str, float | int],
context: Context = Context(),
**kwargs: Any,
) -> list[str]:
with dask.config.set(scheduler="threads"):
# We try to select the area for all paths; any that fail are returned unchanged
out_paths = []
for path in paths:
try:
ds_area = area_selector(path, context, area=area)
out_paths += area_selector_path(
path, area=area, context=context, **kwargs
)
except NotImplementedError:
context.logger.debug(
f"could not convert {path} to xarray; returning the original data"
)
out_paths.append(path)
else:
if out_format in ["nc", "netcdf"]:
out_fname = ".".join(
path.split(".")[:-1]
+ ["area-subset"]
+ [str(a) for a in area]
+ ["nc"]
)
context.logger.debug(f"out_fname: {out_fname}")
ds_area.compute().to_netcdf(out_fname)
out_paths.append(out_fname)
else:
raise NotImplementedError(
f"Output format not recognised {out_format}"
)
return out_paths
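
A minimal end-to-end sketch of the new entry point; the file name and area values are illustrative only:

from cads_adaptors.adaptors import Context
from cads_adaptors.tools import area_selector

# Area given as [north, east, south, west]; the list form is normalised by
# area_to_checked_dictionary before any selection happens.
out_paths = area_selector.area_selector_paths(
    ["download_0.nc"],
    area=[60, 10, 50, -10],
    context=Context(),
    out_format="netcdf",  # forwarded to area_selector_path
)
# Paths that cannot be opened as xarray are passed through unchanged.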
13 changes: 9 additions & 4 deletions cads_adaptors/tools/convertors.py
@@ -519,7 +519,7 @@ def post_open_datasets_modifications(
def open_netcdf_as_xarray_dictionary(
netcdf_file: str,
context: Context = Context(),
open_datasets_kwargs: dict[str, Any] = {},
open_datasets_kwargs: list[dict[str, Any]] | dict[str, Any] = {},
post_open_datasets_kwargs: dict[str, Any] = {},
**kwargs,
) -> dict[str, xr.Dataset]:
@@ -529,9 +529,14 @@ def open_netcdf_as_xarray_dictionary(
"""
fname, _ = os.path.splitext(os.path.basename(netcdf_file))

assert isinstance(
open_datasets_kwargs, dict
), "open_datasets_kwargs must be a dictionary for netCDF"
if isinstance(open_datasets_kwargs, list):
assert (
len(open_datasets_kwargs) == 1
), "Only one set of open_datasets_kwargs allowed for netCDF"
open_datasets_kwargs = open_datasets_kwargs[0]
assert isinstance(
open_datasets_kwargs, dict
), "open_datasets_kwargs must be a single dictionary for netCDF"
# This is to maintain some consistency with the grib file opening
open_datasets_kwargs = {
"chunks": DEFAULT_CHUNKS,
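
A short sketch of the list handling added above (inputs illustrative): netCDF opening takes a single kwargs dictionary, so a one-element list is unwrapped and longer lists are rejected.

# open_datasets_kwargs=[{"decode_times": False}]  -> unwrapped to the dict
# open_datasets_kwargs=[{...}, {...}]             -> AssertionError for netCDF
# open_datasets_kwargs={"decode_times": False}    -> used as-is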