Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace _pre_retrieve with normalise_request for MultiAdaptors #211

Merged
merged 3 commits into from
Oct 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 2 additions & 29 deletions cads_adaptors/adaptors/cds.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,9 +124,6 @@ def normalise_request(self, request: Request) -> Request:
self.input_request = deepcopy(request)
self.context.debug(f"Input request:\n{self.input_request}")

# Apply any pre-mapping modifications
working_request = deepcopy(request)

# Enforce the schema on the input request
schemas = self.schemas
if not isinstance(schemas, list):
Expand All @@ -135,12 +132,10 @@ def normalise_request(self, request: Request) -> Request:
if adaptor_schema := self.adaptor_schema:
schemas = schemas + [adaptor_schema]
for schema in schemas:
working_request = enforce.enforce(
working_request, schema, self.context.logger
)
request = enforce.enforce(request, schema, self.context.logger)

# Pre-mapping modifications
working_request = self.pre_mapping_modifications(working_request)
working_request = self.pre_mapping_modifications(deepcopy(request))

# If specified by the adaptor, intersect the request with the constraints.
# The intersected_request is a list of requests
Expand Down Expand Up @@ -224,28 +219,6 @@ def set_download_format(self, download_format, default_download_format="zip"):
def get_licences(self, request: Request) -> list[tuple[str, int]]:
return self.licences

# TODO: replace call to _pre_retrieve with normalise_request
# Still used in CamsSolarRadiationTimeseriesAdaptor
def _pre_retrieve(self, request: Request, default_download_format="zip"):
self.input_request = deepcopy(request)
self.context.debug(f"Input request:\n{self.input_request}")
self.receipt = request.pop("receipt", False)

# Extract post-process steps from the request before mapping:
self.post_process_steps = request.pop("post_process", [])
self.context.debug(
f"Post-process steps extracted from request:\n{self.post_process_steps}"
)

self.mapped_request = self.apply_mapping(request) # type: ignore

self.download_format = self.mapped_request.pop(
"download_format", default_download_format
)
self.context.debug(
f"Request mapped to (collection_id={self.collection_id}):\n{self.mapped_request}"
)

def pp_mapping(self, in_pp_config: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""Map the post-process steps from the request to the correct functions."""
from cads_adaptors.tools.post_processors import pp_config_mapping
Expand Down
101 changes: 55 additions & 46 deletions cads_adaptors/adaptors/multi.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from copy import deepcopy
from typing import Any

from cads_adaptors import AbstractCdsAdaptor, mapping
from cads_adaptors.adaptors import Request
from cads_adaptors.exceptions import InvalidRequest, MultiAdaptorNoDataError
from cads_adaptors.exceptions import MultiAdaptorNoDataError
from cads_adaptors.tools import adaptor_tools
from cads_adaptors.tools.general import ensure_list


Expand Down Expand Up @@ -80,16 +80,26 @@ def split_adaptors(

return sub_adaptors

def _pre_retrieve(self, request, default_download_format="zip"):
self.input_request = deepcopy(request)
self.receipt = request.pop("receipt", False)
self.mapped_request = mapping.apply_mapping(request, self.mapping)
self.download_format = self.mapped_request.pop(
"download_format", default_download_format
)
def pre_mapping_modifications(self, request: dict[str, Any]) -> dict[str, Any]:
request = super().pre_mapping_modifications(request)

download_format = request.pop("download_format", "zip")
self.set_download_format(download_format)

return request

def retrieve_list_of_results(self, request: Request) -> list[str]:
self._pre_retrieve(request, default_download_format="zip")
request = self.normalise_request(request)
# TODO: handle lists of requests, normalise_request has the power to implement_constraints
# which produces a list of complete hypercube requests.
try:
assert len(self.mapped_requests) == 1
except AssertionError:
self.context.add_user_visible_log(
f"WARNING: More than one request was mapped: {self.mapped_requests}, "
f"returning the first one only:\n{self.mapped_requests[0]}"
)
self.mapped_request = self.mapped_requests[0]

self.context.add_stdout(f"MultiAdaptor, full_request: {self.mapped_request}")

Expand Down Expand Up @@ -122,57 +132,54 @@ def convert_format(self, *args, **kwargs) -> list[str]:

return convert_format(*args, **kwargs)

def _pre_retrieve(self, request, default_download_format="zip"):
self.input_request = deepcopy(request)
self.receipt = request.pop("receipt", False)

# Intersect constraints
if self.config.get("intersect_constraints", False):
requests_after_intersection = self.intersect_constraints(request)
if len(requests_after_intersection) == 0:
msg = "Error: no intersection with the constraints."
raise InvalidRequest(msg)
else:
requests_after_intersection = [request]

self.mapped_requests_pieces = []
for request_piece_after_intersection in requests_after_intersection:
self.mapped_requests_pieces.append(
mapping.apply_mapping(request_piece_after_intersection, self.mapping)
)
def pre_mapping_modifications(self, request: dict[str, Any]) -> dict[str, Any]:
"""Implemented in normalise_request, before the mapping is applied."""
request = super().pre_mapping_modifications(request)

self.download_format = self.mapped_requests_pieces[0].pop(
"download_format", default_download_format
# TODO: Remove legacy syntax all together
data_format = request.pop("format", "grib")
data_format = request.pop("data_format", data_format)

# Account from some horribleness from the legacy system:
if data_format.lower() in ["netcdf.zip", "netcdf_zip", "netcdf4.zip"]:
data_format = "netcdf"
request.setdefault("download_format", "zip")

default_download_format = "as_source"
download_format = request.pop("download_format", default_download_format)
self.set_download_format(
download_format, default_download_format=default_download_format
)

# Apply any mapping
mapped_formats = self.apply_mapping({"data_format": data_format})
# TODO: Add this extra mapping to apply_mapping?
self.data_format = adaptor_tools.handle_data_format(
mapped_formats["data_format"]
)
return request

def retrieve_list_of_results(self, request: Request) -> list[str]:
"""For MultiMarsCdsAdaptor we just want to apply mapping from each adaptor."""
import dask

from cads_adaptors.adaptors.mars import execute_mars
from cads_adaptors.tools import adaptor_tools

# Format of data files, grib or netcdf
data_format = request.pop("format", "grib")
data_format = request.pop("data_format", data_format)
data_format = adaptor_tools.handle_data_format(data_format)

# Account from some horribleness from teh legacy system:
if data_format.lower() in ["netcdf.zip", "netcdf_zip", "netcdf4.zip"]:
data_format = "netcdf"
request.setdefault("download_format", "zip")

self._pre_retrieve(request, default_download_format="as_source")
request = self.normalise_request(request)
# This will apply any top level multi-adaptor mapping, currently not used but could potentially
# be useful to reduce the repetitive config in each sub-adaptor of adaptor.json

mapped_requests = []
# self.mapped_requests contains the schema-checked, intersected and (top-level mapping) mapped request
self.context.add_stdout(
f"MultiMarsCdsAdaptor, full_request: {self.mapped_requests_pieces}"
f"MultiMarsCdsAdaptor, full_request: {self.mapped_requests}"
)

# We now split the mapped_request into sub-adaptors
mapped_requests = []
for adaptor_tag, adaptor_desc in self.config["adaptors"].items():
this_adaptor = adaptor_tools.get_adaptor(adaptor_desc, self.form)
this_values = adaptor_desc.get("values", {})
for mapped_request_piece in self.mapped_requests_pieces:
for mapped_request_piece in self.mapped_requests:
this_request = self.split_request(
mapped_request_piece, this_values, **this_adaptor.config
)
Expand All @@ -191,7 +198,9 @@ def retrieve_list_of_results(self, request: Request) -> list[str]:
result = execute_mars(mapped_requests, context=self.context, config=self.config)

with dask.config.set(scheduler="threads"):
paths = self.convert_format(result, data_format, self.context, self.config)
paths = self.convert_format(
result, self.data_format, self.context, self.config
)

if len(paths) > 1 and self.download_format == "as_source":
self.download_format = "zip"
Expand Down