|
| 1 | +import typing as T |
| 2 | + |
| 3 | +import yaml |
| 4 | + |
| 5 | +from cads_adaptors import AbstractCdsAdaptor |
| 6 | +from cads_adaptors.adaptors import Request |
| 7 | +from cads_adaptors.tools.general import ensure_list |
| 8 | +from cads_adaptors.tools.logger import logger |
| 9 | + |
| 10 | + |
| 11 | +class MultiAdaptor(AbstractCdsAdaptor): |
| 12 | + @staticmethod |
| 13 | + def split_request( |
| 14 | + full_request: Request, # User request |
| 15 | + this_values: T.Dict[str, T.Any], # key: [values] for the adaptor component |
| 16 | + **config: T.Any, |
| 17 | + ) -> Request: |
| 18 | + """ |
| 19 | + Basic request splitter, splits based on whether the values are relevant to |
| 20 | + the specific adaptor. |
| 21 | + More complex constraints may need a more detailed splitter. |
| 22 | + """ |
| 23 | + this_request = {} |
| 24 | + # loop over keys in this_values, i.e. the keys relevant to this_adaptor |
| 25 | + for key in list(this_values): |
| 26 | + # get request values for that key |
| 27 | + req_vals = full_request.get(key, []) |
| 28 | + # filter for values relevant to this_adaptor: |
| 29 | + these_vals = [ |
| 30 | + v for v in ensure_list(req_vals) if v in this_values.get(key, []) |
| 31 | + ] |
| 32 | + if len(these_vals) > 0: |
| 33 | + # if values then add to request |
| 34 | + this_request[key] = these_vals |
| 35 | + elif key in config.get("required_keys", []): |
| 36 | + # If a required key, then return an empty dictionary. |
| 37 | + # optional keys must be set in the adaptor.json via gecko |
| 38 | + return {} |
| 39 | + |
| 40 | + return this_request |
| 41 | + |
| 42 | + def retrieve(self, request: Request): |
| 43 | + from cads_adaptors.tools import adaptor_tools, download_tools |
| 44 | + |
| 45 | + download_format = request.pop("download_format", "zip") |
| 46 | + |
| 47 | + these_requests = {} |
| 48 | + exception_logs: T.Dict[str, str] = {} |
| 49 | + logger.debug(f"MultiAdaptor, full_request: {request}") |
| 50 | + for adaptor_tag, adaptor_desc in self.config["adaptors"].items(): |
| 51 | + this_adaptor = adaptor_tools.get_adaptor(adaptor_desc, self.form) |
| 52 | + this_values = adaptor_desc.get("values", {}) |
| 53 | + |
| 54 | + this_request = self.split_request(request, this_values, **self.config) |
| 55 | + logger.debug(f"MultiAdaptor, {adaptor_tag}, this_request: {this_request}") |
| 56 | + |
| 57 | + # TODO: check this_request is valid for this_adaptor, or rely on try? |
| 58 | + # i.e. split_request does NOT implement constraints. |
| 59 | + if len(this_request) > 0: |
| 60 | + this_request.setdefault("download_format", "list") |
| 61 | + these_requests[this_adaptor] = this_request |
| 62 | + |
| 63 | + results = [] |
| 64 | + for adaptor, req in these_requests.items(): |
| 65 | + try: |
| 66 | + this_result = adaptor.retrieve(req) |
| 67 | + except Exception: |
| 68 | + logger.debug(Exception) |
| 69 | + else: |
| 70 | + results += this_result |
| 71 | + |
| 72 | + # TODO: Add parallelistation via multiprocessing |
| 73 | + # # Allow a maximum of 2 parallel processes |
| 74 | + # import multiprocessing as mp |
| 75 | + |
| 76 | + # pool = mp.Pool(min(len(these_requests), 2)) |
| 77 | + |
| 78 | + # def apply_adaptor(args): |
| 79 | + # try: |
| 80 | + # result = args[0](args[1]) |
| 81 | + # except Exception as err: |
| 82 | + # # Catch any possible exception and store error message in case all adaptors fail |
| 83 | + # logger.debug(f"Adaptor Error ({args}): {err}") |
| 84 | + # result = [] |
| 85 | + # return result |
| 86 | + |
| 87 | + # results = pool.map( |
| 88 | + # apply_adaptor, |
| 89 | + # ((adaptor, request) for adaptor, request in these_requests.items()), |
| 90 | + # ) |
| 91 | + |
| 92 | + if len(results) == 0: |
| 93 | + raise RuntimeError( |
| 94 | + "MultiAdaptor returned no results, the error logs of the sub-adaptors is as follows:\n" |
| 95 | + f"{yaml.safe_dump(exception_logs)}" |
| 96 | + ) |
| 97 | + |
| 98 | + # return self.merge_results(results, prefix=self.collection_id) |
| 99 | + # close files |
| 100 | + [res.close() for res in results] |
| 101 | + # get the paths |
| 102 | + paths = [res.name for res in results] |
| 103 | + |
| 104 | + download_kwargs = dict( |
| 105 | + base_target=f"{self.collection_id}-{hash(tuple(results))}" |
| 106 | + ) |
| 107 | + |
| 108 | + return download_tools.DOWNLOAD_FORMATS[download_format]( |
| 109 | + paths, **download_kwargs |
| 110 | + ) |
0 commit comments