Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix for kwargs handling in post_processors #185

Merged
merged 21 commits into the base branch from the source branch (branch names not captured in this page snapshot)
Jul 29, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 27 additions & 10 deletions cads_adaptors/adaptors/cds.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,16 +137,23 @@ def get_licences(self, request: Request) -> list[tuple[str, int]]:
# and currently only implemented for retrieve methods
def _pre_retrieve(self, request: Request, default_download_format="zip"):
self.input_request = deepcopy(request)
self.context.debug(f"Input request:\n{self.input_request}")
self.receipt = request.pop("receipt", False)

# Extract post-process steps from the request before mapping:
self.post_process_steps = self.pp_mapping(request.pop("post_process", []))
self.context.debug(
f"Post-process steps extracted from request:\n{self.post_process_steps}"
)

self.mapped_request = self.apply_mapping(request) # type: ignore

self.download_format = self.mapped_request.pop(
"download_format", default_download_format
)
self.context.debug(
f"Request mapped to (collection_id={self.collection_id}):\n{self.mapped_request}"
)

def pp_mapping(self, in_pp_config: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""Map the post-process steps from the request to the correct functions."""
Expand All @@ -160,6 +167,9 @@ def pp_mapping(self, in_pp_config: list[dict[str, Any]]) -> list[dict[str, Any]]
def post_process(self, result: Any) -> dict[str, Any]:
"""Perform post-process steps on the retrieved data."""
for i, pp_step in enumerate(self.post_process_steps):
self.context.add_stdout(
f"Performing post-process step {i+1} of {len(self.post_process_steps)}: {pp_step}"
)
# TODO: pp_mapping should have ensured "method" is always present

if "method" not in pp_step:
Expand All @@ -168,9 +178,9 @@ def post_process(self, result: Any) -> dict[str, Any]:
)
continue

method_name = pp_step["method"]
method_name = pp_step.pop("method")
# TODO: Add extra condition to limit pps from dataset configurations
if hasattr(self, method_name):
if not hasattr(self, method_name):
self.context.add_user_visible_error(
message=f"Post-processor method '{method_name}' not available for this dataset"
)
Expand All @@ -180,21 +190,28 @@ def post_process(self, result: Any) -> dict[str, Any]:
# post processing is done on xarray objects,
# so on first pass we ensure result is opened as xarray
if i == 0:
post_processing_kwargs = self.config.get("post_processing_kwargs", {})

from cads_adaptors.tools.convertors import (
open_result_as_xarray_dictionary,
)

post_processing_kwargs = self.config.get("post_processing_kwargs", {})

open_datasets_kwargs = post_processing_kwargs.get(
"open_datasets_kwargs", {}
)
post_open_datasets_kwargs = post_processing_kwargs.get(
"post_open_datasets_kwargs", {}
)
self.context.add_stdout(
f"Opening result: {result} as xarray dictionary with kwargs:\n"
f"open_dataset_kwargs: {open_datasets_kwargs}\n"
f"post_open_datasets_kwargs: {post_open_datasets_kwargs}"
)
result = open_result_as_xarray_dictionary(
result,
context=self.context,
open_datasets_kwargs=post_processing_kwargs.get(
"open_datasets_kwargs", {}
),
post_open_kwargs=post_processing_kwargs.get(
"post_open_datasets_kwargs", {}
),
open_datasets_kwargs=open_datasets_kwargs,
post_open_datasets_kwargs=post_open_datasets_kwargs,
)

result = method(result, **pp_step)
Expand Down
51 changes: 22 additions & 29 deletions cads_adaptors/tools/convertors.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,26 +46,11 @@ def convert_format(
) -> list[str]:
target_format = adaptor_tools.handle_data_format(target_format)
post_processing_kwargs = config.get("post_processing_kwargs", {})
open_datasets_kwargs: dict[str, Any] = post_processing_kwargs.get(
"open_datasets_kwargs", {}
)
post_open_kwargs: dict[str, Any] = post_processing_kwargs.get(
"post_open_datasets_kwargs", {}
)

# Keywords specific to writing to the target format
to_target_kwargs: dict[str, Any] = config.get(f"to_{target_format}_kwargs", {})

convertor: None | Callable = CONVERTORS.get(target_format, None)

if convertor is not None:
return convertor(
result,
context=context,
open_datasets_kwargs=open_datasets_kwargs,
post_open_kwargs=post_open_kwargs,
**to_target_kwargs,
)
return convertor(result, context=context, **post_processing_kwargs)

else:
message = (
Expand Down Expand Up @@ -240,15 +225,19 @@ def result_to_netcdf_files(
def result_to_netcdf_legacy_files(
result: Any,
context: Context = Context(),
command: str | list[str] = ["grib_to_netcdf", "-S", "param"],
filter_rules: str | None = None,
to_netcdf_legacy_kwargs: dict[str, Any] = {},
**kwargs,
) -> list[str]:
"""
Legacy grib_to_netcdf convertor, which will be marked as deprecated.
Can only accept a grib file, or list/dict of grib files as input.
Converts to netCDF3 only.
"""
command: str | list[str] = to_netcdf_legacy_kwargs.get(
"command", ["grib_to_netcdf", "-S", "param"]
)
filter_rules: str | None = to_netcdf_legacy_kwargs.get("filter", None)

context.add_user_visible_error(
"The 'netcdf_legacy' format is deprecated and no longer supported. "
"Users are encouraged to update workflows to use the updated, and CF compliant, 'netcdf' option."
Expand Down Expand Up @@ -383,24 +372,25 @@ def unknown_filetype_to_netcdf_files(
def grib_to_netcdf_files(
grib_file: str,
open_datasets_kwargs: None | dict[str, Any] | list[dict[str, Any]] = None,
post_open_kwargs: dict[str, Any] = {},
post_open_datasets_kwargs: dict[str, Any] = {},
to_netcdf_kwargs: dict[str, Any] = {},
context: Context = Context(),
**to_netcdf_kwargs,
**kwargs,
):
to_netcdf_kwargs.update(to_netcdf_kwargs.pop("to_netcdf_kwargs", {}))
to_netcdf_kwargs.update(kwargs.pop("to_netcdf_kwargs", {}))
grib_file = os.path.realpath(grib_file)

context.add_stdout(
f"Converting {grib_file} to netCDF files with:\n"
f"to_netcdf_kwargs: {to_netcdf_kwargs}\n"
f"open_datasets_kwargs: {open_datasets_kwargs}\n"
f"post_open_kwargs: {post_open_kwargs}\n"
f"post_open_datasets_kwargs: {post_open_datasets_kwargs}\n"
)

datasets = open_grib_file_as_xarray_dictionary(
grib_file,
open_datasets_kwargs=open_datasets_kwargs,
post_open_kwargs=post_open_kwargs,
post_open_datasets_kwargs=post_open_datasets_kwargs,
context=context,
)
# Fail here on empty lists so that error message is more informative
Expand All @@ -413,7 +403,9 @@ def grib_to_netcdf_files(
context.add_stderr(message=message)
raise RuntimeError(message)

out_nc_files = xarray_dict_to_netcdf(datasets, context=context, **to_netcdf_kwargs)
out_nc_files = xarray_dict_to_netcdf(
datasets, context=context, to_netcdf_kwargs=to_netcdf_kwargs
)

return out_nc_files

Expand All @@ -422,8 +414,9 @@ def xarray_dict_to_netcdf(
datasets: dict[str, xr.Dataset],
context: Context = Context(),
compression_options: str | dict[str, Any] = "default",
to_netcdf_kwargs: dict[str, Any] = {},
out_fname_prefix: str = "",
**to_netcdf_kwargs,
**kwargs,
) -> list[str]:
"""
Convert a dictionary of xarray datasets to netCDF files, where the key of the dictionary
Expand Down Expand Up @@ -603,7 +596,7 @@ def open_netcdf_as_xarray_dictionary(
netcdf_file: str,
context: Context = Context(),
open_datasets_kwargs: dict[str, Any] = {},
post_open_kwargs: dict[str, Any] = {},
post_open_datasets_kwargs: dict[str, Any] = {},
**kwargs,
) -> dict[str, xr.Dataset]:
"""
Expand All @@ -625,15 +618,15 @@ def open_netcdf_as_xarray_dictionary(
context.add_stdout(f"Opening {netcdf_file} with kwargs: {open_datasets_kwargs}")
datasets = {fname: xr.open_dataset(netcdf_file, **open_datasets_kwargs)}

datasets = post_open_datasets_modifications(datasets, **post_open_kwargs)
datasets = post_open_datasets_modifications(datasets, **post_open_datasets_kwargs)

return datasets


def open_grib_file_as_xarray_dictionary(
grib_file: str,
open_datasets_kwargs: None | dict[str, Any] | list[dict[str, Any]] = None,
post_open_kwargs: dict[str, Any] = {},
post_open_datasets_kwargs: dict[str, Any] = {},
context: Context = Context(),
**kwargs,
) -> dict[str, xr.Dataset]:
Expand Down Expand Up @@ -686,6 +679,6 @@ def open_grib_file_as_xarray_dictionary(
)
}

datasets = post_open_datasets_modifications(datasets, **post_open_kwargs)
datasets = post_open_datasets_modifications(datasets, **post_open_datasets_kwargs)

return datasets
4 changes: 2 additions & 2 deletions cads_adaptors/tools/post_processors.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,8 @@ def daily_reduce(
out_xarray_dict = {}
for in_tag, in_dataset in in_xarray_dict.items():
out_tag = f"{in_tag}_daily-{how}"
context.add_stdout(f"Daily reduction: {how} {kwargs}")
context.add_user_visible_log(f"Temporal reduction: {how} {kwargs}")
context.add_stdout(f"Daily reduction: {how} {kwargs}\n{in_dataset}")
context.add_user_visible_log(f"Daily reduction: {how} {kwargs}")
reduced_data = temporal.daily_reduce(
in_dataset,
how=how,
Expand Down