diff --git a/containers/sat/Containerfile b/containers/sat/Containerfile
index ff76556..c974753 100644
--- a/containers/sat/Containerfile
+++ b/containers/sat/Containerfile
@@ -11,7 +11,7 @@ ENV GDAL_CONFIG=/venv/bin/gdal-config
 # Build the virtualenv
 FROM build-venv as install-reqs
 RUN /venv/bin/python -m pip install -q -U diskcache pyproj \
-    pyresample xarray pyyaml ocf_blosc2 eumdac requests dask zarr
+    pyresample xarray pyyaml ocf_blosc2 eumdac requests dask zarr tqdm
 
 # Copy the virtualenv into a distroless image
 FROM gcr.io/distroless/python3-debian11
diff --git a/containers/sat/download_process_sat.py b/containers/sat/download_process_sat.py
index f8fcf66..deaa67b 100644
--- a/containers/sat/download_process_sat.py
+++ b/containers/sat/download_process_sat.py
@@ -6,33 +6,29 @@
 import argparse
 import dataclasses
 import datetime as dt
-import itertools
-import json
 import logging
 import os
 import pathlib
 import shutil
 import sys
 import traceback
-from multiprocessing import Pool, cpu_count
+from collections.abc import Iterator
 from typing import Literal
 
 import dask.delayed
-import dask.distributed
 import dask.diagnostics
+import dask.distributed
 import eumdac
 import eumdac.cli
+import eumdac.product
 import numpy as np
 import pandas as pd
 import pyproj
 import pyresample
-import satpy.dataset.dataid
 import xarray as xr
 import yaml
-import zarr
 from ocf_blosc2 import Blosc2
-
 from satpy import Scene
+from tqdm import tqdm
+import zarr
 
 if sys.stdout.isatty():
     # Simple logging for terminals
@@ -57,6 +53,7 @@
     datefmt="%Y-%m-%dT%H:%M:%S",
 )
 
+# Reduce verbosity of dependencies
 for logger in [
     "cfgrib",
     "charset_normalizer",
@@ -69,6 +66,7 @@
     "urllib3",
 ]:
     logging.getLogger(logger).setLevel(logging.ERROR)
+np.seterr(divide="ignore")
 
 log = logging.getLogger("sat-etl")
 
@@ -162,165 +160,340 @@ class Channel:
     ],
 }
 
-
-def download_scans(
+def get_products_iterator(
     sat_config: Config,
-    folder: pathlib.Path,
-    scan_time: pd.Timestamp,
+    start: dt.datetime,
+    end: dt.datetime,
     token: eumdac.AccessToken,
-) -> list[pathlib.Path]:
-    """Download satellite scans for a satellite at a given time.
+) -> tuple[Iterator[eumdac.product.Product], int]:
+    """Get an iterator over the products for a given satellite in a given time range.
+
+    Checks that the number of products returned matches the expected number of products.
+    """
+    log.info(
+        f"Searching for products between {start!s} and {end!s} "
+        f"for {sat_config.product_id} ",
+    )
+    expected_products_count = int((end - start) / pd.Timedelta(sat_config.cadence))
+    datastore = eumdac.DataStore(token)
+    collection = datastore.get_collection(sat_config.product_id)
+    search_results: eumdac.SearchResults = collection.search(
+        dtstart=start,
+        dtend=end,
+        sort="start,time,1",  # Sort by ascending start time
+    )
+    log.info(
+        f"Found {search_results.total_results}/{expected_products_count} products "
+        f"for {sat_config.product_id} ",
+    )
+    return search_results.__iter__(), search_results.total_results
+
+
+def download_nat(
+    product: eumdac.product.Product,
+    folder: pathlib.Path,
+    retries: int = 6,
+) -> pathlib.Path | None:
+    """Download a product to a folder.
+
+    EUMDAC products are collections of files, with a `.nat` file containing the data,
+    and with `.xml` files containing metadata. This function only downloads the `.nat` files,
+    skipping any files that are already present in the folder.
 
     Args:
-        sat_config: Configuration for the satellite.
-        folder: Folder to download the files to.
-        scan_time: Time to download the files for.
- token: EUMETSTAT access token. + product: Product to download. + folder: Folder to download the product to. + retries: Number of times to retry downloading the product. Returns: - List of downloaded files. + Path to the downloaded file, or None if the download failed. """ - files: list[pathlib.Path] = [] - - # Download - window_start: pd.Timestamp = scan_time - pd.Timedelta(sat_config.cadence) - window_end: pd.Timestamp = scan_time + pd.Timedelta(sat_config.cadence) - - try: - datastore = eumdac.DataStore(token) - collection = datastore.get_collection(sat_config.product_id) - search_results = collection.search( - dtstart=window_start.to_pydatetime(), - dtend=window_end.to_pydatetime(), + folder.mkdir(parents=True, exist_ok=True) + nat_files: list[str] = [p for p in product.entries if p.endswith(".nat")] + if len(nat_files) != 1: + log.warning( + f"Product '{product}' contains {len(nat_files)} .nat files. " + "Expected 1. Skipping download.", ) - except Exception as e: - log.error(f"Error finding products: {e}") - return [] - - products_count: int = 0 - for product in search_results: - for entry in list(filter(lambda p: p.endswith(".nat"), product.entries)): - filepath: pathlib.Path = folder / entry - # Prevent downloading existing files - if filepath.exists(): - log.debug("Skipping existing file: {filepath}") - files.append(filepath) - continue - # Try download a few times - attempts: int = 1 - while attempts < 6: - try: - folder.mkdir(parents=True, exist_ok=True) - with ( - product.open(entry) as fsrc, - filepath.open("wb") as fdst, - ): - shutil.copyfileobj(fsrc, fdst) - files.append(filepath) - attempts = 1000 - except Exception as e: - log.warning( - f"Error downloading product '{product}' (attempt {attempts}): '{e}'", - ) - attempts += 1 - products_count += 1 - - if products_count == 0: - log.warning(f"No products found for {scan_time}") - - return files + return None + nat_filename: str = nat_files[0] -def _fname_to_scantime(fname: str) -> dt.datetime: - """Converts a filename to a datetime. + filepath: pathlib.Path = folder / nat_filename + if filepath.exists(): + log.debug(f"Skipping existing file: {filepath}") + return filepath - Files are of the form: - `MSG2-SEVI-MSG15-0100-NA-20230910221240.874000000Z-NA.nat` - So determine the time from the first element split by '.'. - """ - return dt.datetime.strptime(fname.split(".")[0][-14:], "%Y%m%d%H%M%S") + for i in range(retries): + try: + with (product.open(nat_filename) as fsrc, filepath.open("wb") as fdst): + shutil.copyfileobj(fsrc, fdst, length=16 * 1024) + return filepath + except Exception as e: + log.warning( + f"Error downloading product '{product}' (attempt {i}/{retries}): '{e}'", + ) + + log.error(f"Failed to download product '{product}' after {retries} attempts.") + return None -def process_scans( +def process_nat( sat_config: Config, - folder: pathlib.Path, - start: dt.date, - end: dt.date, + path: pathlib.Path, dstype: Literal["hrv", "nonhrv"], -) -> str: - """Process the downloaded scans into a zarr store. +) -> xr.DataArray | None: + """Process a `.nat` file into an xarray dataset. Args: - sat_config: Configuration for the satellite. - folder: Folder to download the files to. - start: Start date for the processing. - end: End date for the processing. + path: Path to the `.nat` file to open. dstype: Type of data to process (hrv or nonhrv). 
""" - # Check zarr file exists for the month - zarr_path: pathlib.Path = folder.parent / start.strftime(sat_config.zarr_fmtstr[dstype]) - zarr_times: list[dt.datetime] = [] - if zarr_path.exists(): - zarr_times = xr.open_zarr(zarr_path, consolidated=True).sortby("time").time.values.tolist() - log.debug(f"Zarr store already exists at {zarr_path} for {zarr_times[0]}-{zarr_times[-1]}") - else: - log.debug(f"Zarr store does not exist at {zarr_path}") - - # Get native files in order - native_files: list[pathlib.Path] = list(folder.glob("*.nat")) - native_files.sort() - wanted_files = [f for f in native_files if start <= _fname_to_scantime(f.name) < end] - log.info(f"Found {len(wanted_files)} native files within date range at {folder.as_posix()}") - - # Convert native files to xarray datasets - # * Append to the monthly zarr in hourly chunks - datasets: list[xr.Dataset] = [] - i: int - f: pathlib.Path - for i, f in enumerate(wanted_files): - try: - # TODO: This method of passing the zarr times to the open function leaves a lot to be desired - # Firstly, if the times are not passed in sorted order then the created 12-dataset chunks - # may have missed times in them. Secondly, determining the time still requires opening and - # converting the file which is probably slow. Better to skip search for files whose times - # are already in the Zarr store in the first place and bypass the entire pipeline. - dataset: xr.Dataset | None = _open_and_scale_data(zarr_times, f.as_posix(), dstype) - except Exception as e: - log.error(f"Error opening/scaling data for file {f}: {e}") - continue - if dataset is not None: - dataset = _preprocess_function(dataset) - datasets.append(dataset) - # Append to zarrs in hourly chunks - # * This is so zarr doesn't complain about mismatching chunk sizes - if len(datasets) == int(pd.Timedelta("1h") / pd.Timedelta(sat_config.cadence)): - if pathlib.Path(zarr_path).exists(): - log.debug(f"Appending to existing zarr store at {zarr_path}") - mode = "a" - else: - log.debug(f"Creating new zarr store at {zarr_path}") - mode = "w" - concat_ds: xr.Dataset = xr.concat(datasets, dim="time") - _write_to_zarr( - concat_ds, - zarr_path.as_posix(), - mode, - chunks={ - "time": len(datasets), - "x_geostationary": -1, - "y_geostationary": -1, - "variable": 1, - }, - ) - datasets = [] + # The reader is the same for each satellite as the sensor is the same + # * Hence "seviri" in all cases + try: + scene = Scene(filenames={"seviri_l1b_native": [path.as_posix()]}) + scene.load([c.variable for c in CHANNELS[dstype]]) + except Exception as e: + raise OSError(f"Error reading '{path!s}' as satpy Scene: {e}") from e - log.info(f"Process loop [{dstype}]: {i+1}/{len(wanted_files)}") + try: + da: xr.DataArray = _convert_scene_to_dataarray( + scene, + band=CHANNELS[dstype][0].variable, + area="RSS", + calculate_osgb=False, + ) + except Exception as e: + raise ValueError(f"Error converting '{path!s}' to DataArray: {e}") from e - # Consolidate zarr metadata - if pathlib.Path(zarr_path).exists(): - _rewrite_zarr_times(zarr_path.as_posix()) + # Rescale the data, save as dataarray + # TODO: Left over from Jacob, probbaly don't want this + try: + da = _rescale(da, CHANNELS[dstype]) + except Exception as e: + raise ValueError(f"Error rescaling dataarray: {e}") from e + + # Reorder the coordinates, and set the data type + da = da.transpose("time", "y_geostationary", "x_geostationary", "variable") + da = da.astype(np.float16) - check_data_quality(xr.open_zarr(zarr_path, consolidated=True)) + return da - return 
dstype +def write_to_zarr( + da: xr.DataArray, + zarr_path: pathlib.Path, +) -> None: + """Write the given data array to the given zarr store. + + If a Zarr store already exists at the given path, the DataArray will be appended to it. + + Any attributes on the dataarray object are serialized to json-compatible strings. + """ + mode = "a" if zarr_path.exists() else "w" + extra_kwargs = { + "append_dim": "time", + } if mode == "a" else { + "encoding": { + "data": {"compressor": Blosc2("zstd", clevel=5)}, + "time": {"units": "nanoseconds since 1970-01-01"}, + }, + } + # Convert attributes to be json serializable + for key, value in da.attrs.items(): + if isinstance(value, dict): + # Convert np.float32 to Python floats (otherwise yaml.dump complains) + for inner_key in value: + inner_value = value[inner_key] + if isinstance(inner_value, np.floating): + value[inner_key] = float(inner_value) + da.attrs[key] = yaml.dump(value) + if isinstance(value, bool | np.bool_): + da.attrs[key] = str(value) + if isinstance(value, pyresample.geometry.AreaDefinition): + da.attrs[key] = value.dump() + # Convert datetimes + if isinstance(value, dt.datetime): + da.attrs[key] = value.isoformat() + + try: + write_job = da.chunk({ + "time": 1, + "x_geostationary": -1, + "y_geostationary": -1, + "variable": 1, + }).to_dataset( + name="data", + promote_attrs=True, + ).to_zarr( + store=zarr_path, + compute=True, + consolidated=True, + mode=mode, + **extra_kwargs, + ) + except Exception as e: + log.error(f"Error writing dataset to zarr store {zarr_path} with mode {mode}: {e}") + traceback.print_tb(e.__traceback__) + + return None + +#def download_scans( +# sat_config: Config, +# folder: pathlib.Path, +# scan_time: pd.Timestamp, +# token: eumdac.AccessToken, +#) -> list[pathlib.Path]: +# """Download satellite scans for a satellite at a given time. +# +# Args: +# sat_config: Configuration for the satellite. +# folder: Folder to download the files to. +# scan_time: Time to download the files for. +# token: EUMETSTAT access token. +# +# Returns: +# List of downloaded files. 
+# """ +# files: list[pathlib.Path] = [] +# +# # Download +# window_start: pd.Timestamp = scan_time - pd.Timedelta(sat_config.cadence) +# window_end: pd.Timestamp = scan_time + pd.Timedelta(sat_config.cadence) +# +# try: +# datastore = eumdac.DataStore(token) +# collection = datastore.get_collection(sat_config.product_id) +# search_results = collection.search( +# dtstart=window_start.to_pydatetime(), +# dtend=window_end.to_pydatetime(), +# ) +# except Exception as e: +# log.error(f"Error finding products: {e}") +# return [] +# +# products_count: int = 0 +# for product in search_results: +# for entry in list(filter(lambda p: p.endswith(".nat"), product.entries)): +# filepath: pathlib.Path = folder / entry +# # Prevent downloading existing files +# if filepath.exists(): +# log.debug("Skipping existing file: {filepath}") +# files.append(filepath) +# continue +# # Try download a few times +# attempts: int = 1 +# while attempts < 6: +# try: +# folder.mkdir(parents=True, exist_ok=True) +# with ( +# product.open(entry) as fsrc, +# filepath.open("wb") as fdst, +# ): +# shutil.copyfileobj(fsrc, fdst) +# files.append(filepath) +# attempts = 1000 +# except Exception as e: +# log.warning( +# f"Error downloading product '{product}' (attempt {attempts}): '{e}'", +# ) +# attempts += 1 +# products_count += 1 +# +# if products_count == 0: +# log.warning(f"No products found for {scan_time}") +# +# return files + +def _fname_to_scantime(fname: str) -> dt.datetime: + """Converts a filename to a datetime. + + Files are of the form: + `MSGX-SEVI-MSG15-0100-NA-20230910221240.874000000Z-NA.nat` + So determine the time from the first element split by '.'. + """ + return dt.datetime.strptime(fname.split(".")[0][-14:], "%Y%m%d%H%M%S") + +#def process_scans( +# sat_config: Config, +# folder: pathlib.Path, +# start: dt.date, +# end: dt.date, +# dstype: Literal["hrv", "nonhrv"], +#) -> str: +# """Process the downloaded scans into a zarr store. +# +# Args: +# sat_config: Configuration for the satellite. +# folder: Folder to download the files to. +# start: Start date for the processing. +# end: End date for the processing. +# dstype: Type of data to process (hrv or nonhrv). +# """ +# # Check zarr file exists for the month +# zarr_path: pathlib.Path = folder.parent / start.strftime(sat_config.zarr_fmtstr[dstype]) +# zarr_times: list[dt.datetime] = [] +# if zarr_path.exists(): +# zarr_times = xr.open_zarr(zarr_path, consolidated=True).sortby("time").time.values.tolist() +# log.debug(f"Zarr store already exists at {zarr_path} for {zarr_times[0]}-{zarr_times[-1]}") +# else: +# log.debug(f"Zarr store does not exist at {zarr_path}") +# +# # Get native files in order +# native_files: list[pathlib.Path] = list(folder.glob("*.nat")) +# native_files.sort() +# wanted_files = [f for f in native_files if start <= _fname_to_scantime(f.name) < end] +# log.info(f"Found {len(wanted_files)} native files within date range at {folder.as_posix()}") +# +# # Convert native files to xarray datasets +# # * Append to the monthly zarr in hourly chunks +# datasets: list[xr.Dataset] = [] +# i: int +# f: pathlib.Path +# for i, f in enumerate(wanted_files): +# try: +# # TODO: This method of passing the zarr times to the open function leaves a lot to be desired +# # Firstly, if the times are not passed in sorted order then the created 12-dataset chunks +# # may have missed times in them. Secondly, determining the time still requires opening and +# # converting the file which is probably slow. 
Better to skip search for files whose times +# # are already in the Zarr store in the first place and bypass the entire pipeline. +# dataset: xr.Dataset | None = _open_and_scale_data(zarr_times, f.as_posix(), dstype) +# except Exception as e: +# log.error(f"Error opening/scaling data for file {f}: {e}") +# continue +# if dataset is not None: +# dataset = _preprocess_function(dataset) +# datasets.append(dataset) +# # Append to zarrs in hourly chunks +# # * This is so zarr doesn't complain about mismatching chunk sizes +# if len(datasets) == int(pd.Timedelta("1h") / pd.Timedelta(sat_config.cadence)): +# if pathlib.Path(zarr_path).exists(): +# log.debug(f"Appending to existing zarr store at {zarr_path}") +# mode = "a" +# else: +# log.debug(f"Creating new zarr store at {zarr_path}") +# mode = "w" +# concat_ds: xr.Dataset = xr.concat(datasets, dim="time") +# _write_to_zarr( +# concat_ds, +# zarr_path.as_posix(), +# mode, +# chunks={ +# "time": len(datasets), +# "x_geostationary": -1, +# "y_geostationary": -1, +# "variable": 1, +# }, +# ) +# datasets = [] +# +# log.info(f"Process loop [{dstype}]: {i+1}/{len(wanted_files)}") +# +# # Consolidate zarr metadata +# if pathlib.Path(zarr_path).exists(): +# _rewrite_zarr_times(zarr_path.as_posix()) +# +# check_data_quality(xr.open_zarr(zarr_path, consolidated=True)) +# +# return dstype def _gen_token() -> eumdac.AccessToken: @@ -331,6 +504,27 @@ def _gen_token() -> eumdac.AccessToken: return token +def _get_attrs_from_scene(scene: Scene) -> dict[str, str]: + """Get the attributes from a Scene object.""" + for key, value in attrs.items(): + # Convert Dicts + if isinstance(value, dict): + # Convert np.float32 to Python floats (otherwise yaml.dump complains) + for inner_key in value: + inner_value = value[inner_key] + if isinstance(inner_value, np.floating): + value[inner_key] = float(inner_value) + attrs[key] = yaml.dump(value) + # Convert Numpy bools + if isinstance(value, bool | np.bool_): + attrs[key] = str(value) + # Convert area + if isinstance(value, pyresample.geometry.AreaDefinition): + attrs[key] = value.dump() + # Convert datetimes + if isinstance(value, dt.datetime): + attrs[key] = value.isoformat() + def _convert_scene_to_dataarray( scene: Scene, @@ -376,53 +570,53 @@ def _convert_scene_to_dataarray( if attr not in ["area", "_satpy_id"]: data_attrs[new_name] = scene[channel].attrs[attr].__repr__() - dataset: xr.Dataset = scene.to_xarray_dataset() - dataarray = dataset.to_array() + ds: xr.Dataset = scene.to_xarray_dataset() + da = ds.to_array() # Lat and Lon are the same for all the channels now if calculate_osgb: lon, lat = scene[band].attrs["area"].get_lonlats() osgb_x, osgb_y = transformer.transform(lat, lon) # Assign x_osgb and y_osgb and set some attributes - dataarray = dataarray.assign_coords( + da = da.assign_coords( x_osgb=(("y", "x"), np.float32(osgb_x)), y_osgb=(("y", "x"), np.float32(osgb_y)), ) for name in ["x_osgb", "y_osgb"]: - dataarray[name].attrs = { + da[name].attrs = { "units": "meter", "coordinate_reference_system": "OSGB", } - dataarray.x_osgb.attrs["name"] = "Easting" - dataarray.y_osgb.attrs["name"] = "Northing" + da.x_osgb.attrs["name"] = "Easting" + da.y_osgb.attrs["name"] = "Northing" for name in ["x", "y"]: - dataarray[name].attrs["coordinate_reference_system"] = "geostationary" + da[name].attrs["coordinate_reference_system"] = "geostationary" log.debug("Calculated OSGB") # Round to the nearest 5 minutes - data_attrs["end_time"] = pd.Timestamp(dataarray.attrs["end_time"]).round("5 min").__str__() - dataarray.attrs = 
data_attrs + data_attrs["end_time"] = pd.Timestamp(da.attrs["end_time"]).round("5 min").__str__() + da.attrs.update(data_attrs) # Rename x and y to make clear the coordinate system they are in - dataarray = dataarray.rename({"x": "x_geostationary", "y": "y_geostationary"}) - if "time" not in dataarray.dims: - time = pd.to_datetime(pd.Timestamp(dataarray.attrs["end_time"]).round("5 min")) - dataarray = dataarray.assign_coords({"time": time}).expand_dims("time") + da = da.rename({"x": "x_geostationary", "y": "y_geostationary"}) + if "time" not in da.dims: + time = pd.to_datetime(pd.Timestamp(da.attrs["end_time"]).round("5 min")) + da = da.assign_coords({"time": time}).expand_dims("time") - del dataarray["crs"] + del da["crs"] del scene log.debug("Finished conversion") - return dataarray + return da -def _rescale(dataarray: xr.DataArray, channels: list[Channel]) -> xr.DataArray: +def _rescale(da: xr.DataArray, channels: list[Channel]) -> xr.DataArray: """Rescale Xarray DataArray so all values lie in the range [0, 1]. Warning: The original `dataarray` will be modified in-place. Args: - dataarray: DataArray to rescale. + da: DataArray to rescale. Dims MUST be named ('time', 'x_geostationary', 'y_geostationary', 'variable')! channels: List of Channel objects with minimum and maximum values for each channel. @@ -430,7 +624,7 @@ def _rescale(dataarray: xr.DataArray, channels: list[Channel]) -> xr.DataArray: The DataArray rescaled to [0, 1]. NaNs in the original `dataarray` will still be present in the returned dataarray. The returned DataArray will be float32. """ - dataarray = dataarray.reindex( + da = da.reindex( {"variable": [c.variable for c in channels]}, ).transpose( "time", @@ -440,154 +634,154 @@ def _rescale(dataarray: xr.DataArray, channels: list[Channel]) -> xr.DataArray: ) # For each channel, subtract the minimum and divide by the range - dataarray -= [c.minimum for c in channels] - dataarray /= [c.maximum - c.minimum for c in channels] + da -= [c.minimum for c in channels] + da /= [c.maximum - c.minimum for c in channels] # Since the mins and maxes are approximations, clip the values to [0, 1] - dataarray = dataarray.clip(min=0, max=1) - dataarray = dataarray.astype(np.float32) - return dataarray - - -def _open_and_scale_data( - zarr_times: list[dt.datetime], - f: str, - dstype: Literal["hrv", "nonhrv"], -) -> xr.Dataset | None: - """Opens a raw file and converts it to a normalised xarray dataset. - - Args: - zarr_times: List of times already in the zarr store. - f: Path to the file to open. - dstype: Type of data to process (hrv or nonhrv). 
- """ - # The reader is the same for each satellite as the sensor is the same - # * Hence "seviri" in all cases - try: - scene = Scene(filenames={"seviri_l1b_native": [f]}) - scene.load([c.variable for c in CHANNELS[dstype]]) - except Exception as e: - raise OSError(f"Error loading scene from file {f}: {e}") from e - - try: - da: xr.DataArray = _convert_scene_to_dataarray( - scene, - band=CHANNELS[dstype][0].variable, - area="RSS", - calculate_osgb=False, - ) - except Exception as e: - log.error(f"Error converting scene to dataarray: {e}") - return None - - # Don't proceed if the dataarray time is already present in the zarr store - if da.time.values[0] in zarr_times: - log.debug(f"Skipping: {da.time.values[0]}") - return None - - # Rescale the data, save as dataset - try: - da = _rescale(da, CHANNELS[dstype]) - except Exception as e: - log.error(f"Error rescaling dataarray: {e}") - return None - - da = da.transpose("time", "y_geostationary", "x_geostationary", "variable") - ds: xr.Dataset = da.to_dataset(name="data", promote_attrs=True) - ds["data"] = ds["data"].astype(np.float16) - - return ds - - -def _preprocess_function(xr_data: xr.Dataset) -> xr.Dataset: - """Updates the coordinates for the given dataset.""" - attrs = xr_data.attrs - y_coords = xr_data.coords["y_geostationary"].values - x_coords = xr_data.coords["x_geostationary"].values - x_dataarray: xr.DataArray = xr.DataArray( - data=np.expand_dims(xr_data.coords["x_geostationary"].values, axis=0), - dims=["time", "x_geostationary"], - coords={"time": xr_data.coords["time"].values, "x_geostationary": x_coords}, - ) - y_dataarray: xr.DataArray = xr.DataArray( - data=np.expand_dims(xr_data.coords["y_geostationary"].values, axis=0), - dims=["time", "y_geostationary"], - coords={"time": xr_data.coords["time"].values, "y_geostationary": y_coords}, - ) - xr_data["x_geostationary_coordinates"] = x_dataarray - xr_data["y_geostationary_coordinates"] = y_dataarray - xr_data.attrs = attrs - return xr_data - - -def _write_to_zarr(dataset: xr.Dataset, zarr_name: str, mode: str, chunks: dict) -> None: - """Writes the given dataset to the given zarr store.""" - log.info("Writing to Zarr") - mode_extra_kwargs: dict[str, dict] = { - "a": {"append_dim": "time"}, - "w": { - "encoding": { - "data": { - "compressor": Blosc2("zstd", clevel=5), - }, - "time": {"units": "nanoseconds since 1970-01-01"}, - }, - }, - } - extra_kwargs = mode_extra_kwargs[mode] - sliced_ds: xr.Dataset = dataset.isel(x_geostationary=slice(0, 5548)).chunk(chunks) - try: - write_job = sliced_ds.to_zarr( - store=zarr_name, - compute=False, - consolidated=True, - mode=mode, - **extra_kwargs, - ) - with dask.diagnostics.ProgressBar(): - write_job.compute() - except Exception as e: - log.error(f"Error writing dataset to zarr store {zarr_name} with mode {mode}: {e}") - traceback.print_tb(e.__traceback__) - return None - - -def _rewrite_zarr_times(output_name: str) -> None: - """Rewrites the time coordinates in the given zarr store.""" - # Combine time coords - ds = xr.open_zarr(output_name, consolidated=True) - - # Prevent numcodecs string error - # See https://github.com/pydata/xarray/issues/3476#issuecomment-1205346130 - for v in list(ds.coords.keys()): - if ds.coords[v].dtype == object: - ds[v].encoding.clear() - for v in list(ds.variables.keys()): - if ds[v].dtype == object: - ds[v].encoding.clear() - - del ds["data"] - if "x_geostationary_coordinates" in ds: - del ds["x_geostationary_coordinates"] - if "y_geostationary_coordinates" in ds: - del ds["y_geostationary_coordinates"] 
- # Need to remove these encodings to avoid chunking - del ds.time.encoding["chunks"] - del ds.time.encoding["preferred_chunks"] - ds.to_zarr(f"{output_name.split('.zarr')[0]}_coord.zarr", consolidated=True, mode="w") - # Remove current time ones - shutil.rmtree(f"{output_name}/time/") - # Add new time ones - shutil.copytree(f"{output_name.split('.zarr')[0]}_coord.zarr/time", f"{output_name}/time") - - # Now replace the part of the .zmetadata with the part of the .zmetadata from the new coord one - with open(f"{output_name}/.zmetadata") as f: - data = json.load(f) - with open(f"{output_name.split('.zarr')[0]}_coord.zarr/.zmetadata") as f2: - coord_data = json.load(f2) - data["metadata"]["time/.zarray"] = coord_data["metadata"]["time/.zarray"] - with open(f"{output_name}/.zmetadata", "w") as f: - json.dump(data, f) - zarr.consolidate_metadata(output_name) + da = da.clip(min=0, max=1) + da = da.astype(np.float32) + return da + + +#def _open_and_scale_data( +# zarr_times: list[dt.datetime], +# f: str, +# dstype: Literal["hrv", "nonhrv"], +#) -> xr.Dataset | None: +# """Opens a raw file and converts it to a normalised xarray dataset. +# +# Args: +# zarr_times: List of times already in the zarr store. +# f: Path to the file to open. +# dstype: Type of data to process (hrv or nonhrv). +# """ +# # The reader is the same for each satellite as the sensor is the same +# # * Hence "seviri" in all cases +# try: +# scene = Scene(filenames={"seviri_l1b_native": [f]}) +# scene.load([c.variable for c in CHANNELS[dstype]]) +# except Exception as e: +# raise OSError(f"Error loading scene from file {f}: {e}") from e +# +# try: +# da: xr.DataArray = _convert_scene_to_dataarray( +# scene, +# band=CHANNELS[dstype][0].variable, +# area="RSS", +# calculate_osgb=False, +# ) +# except Exception as e: +# log.error(f"Error converting scene to dataarray: {e}") +# return None +# +# # Don't proceed if the dataarray time is already present in the zarr store +# if da.time.values[0] in zarr_times: +# log.debug(f"Skipping: {da.time.values[0]}") +# return None +# +# # Rescale the data, save as dataset +# try: +# da = _rescale(da, CHANNELS[dstype]) +# except Exception as e: +# log.error(f"Error rescaling dataarray: {e}") +# return None +# +# da = da.transpose("time", "y_geostationary", "x_geostationary", "variable") +# ds: xr.Dataset = da.to_dataset(name="data", promote_attrs=True) +# ds["data"] = ds["data"].astype(np.float16) +# +# return ds + + +#def _preprocess_function(xr_data: xr.Dataset) -> xr.Dataset: +# """Updates the coordinates for the given dataset.""" +# # TODO: Understand why this is necessary! 
+# attrs = xr_data.attrs +# y_coords = xr_data.coords["y_geostationary"].values +# x_coords = xr_data.coords["x_geostationary"].values +# x_dataarray: xr.DataArray = xr.DataArray( +# data=np.expand_dims(xr_data.coords["x_geostationary"].values, axis=0), +# dims=["time", "x_geostationary"], +# coords={"time": xr_data.coords["time"].values, "x_geostationary": x_coords}, +# ) +# y_dataarray: xr.DataArray = xr.DataArray( +# data=np.expand_dims(xr_data.coords["y_geostationary"].values, axis=0), +# dims=["time", "y_geostationary"], +# coords={"time": xr_data.coords["time"].values, "y_geostationary": y_coords}, +# ) +# xr_data["x_geostationary_coordinates"] = x_dataarray +# xr_data["y_geostationary_coordinates"] = y_dataarray +# xr_data.attrs = attrs +# return xr_data + + +#def _write_to_zarr(dataset: xr.Dataset, zarr_name: str, mode: str, chunks: dict) -> None: +# """Writes the given dataset to the given zarr store.""" +# log.info("Writing to Zarr") +# mode_extra_kwargs: dict[str, dict] = { +# "a": {"append_dim": "time"}, +# "w": { +# "encoding": { +# "data": { +# "compressor": Blosc2("zstd", clevel=5), +# }, +# "time": {"units": "nanoseconds since 1970-01-01"}, +# }, +# }, +# } +# extra_kwargs = mode_extra_kwargs[mode] +# sliced_ds: xr.Dataset = dataset.isel(x_geostationary=slice(0, 5548)).chunk(chunks) +# try: +# write_job = sliced_ds.to_zarr( +# store=zarr_name, +# compute=False, +# consolidated=True, +# mode=mode, +# **extra_kwargs, +# ) +# with dask.diagnostics.ProgressBar(): +# write_job.compute() +# except Exception as e: +# log.error(f"Error writing dataset to zarr store {zarr_name} with mode {mode}: {e}") +# traceback.print_tb(e.__traceback__) +# return None + +#def _rewrite_zarr_times(output_name: str) -> None: +# """Rewrites the time coordinates in the given zarr store.""" +# # Combine time coords +# ds = xr.open_zarr(output_name, consolidated=True) +# +# # Prevent numcodecs string error +# # See https://github.com/pydata/xarray/issues/3476#issuecomment-1205346130 +# for v in list(ds.coords.keys()): +# if ds.coords[v].dtype == object: +# ds[v].encoding.clear() +# for v in list(ds.variables.keys()): +# if ds[v].dtype == object: +# ds[v].encoding.clear() +# +# del ds["data"] +# if "x_geostationary_coordinates" in ds: +# del ds["x_geostationary_coordinates"] +# if "y_geostationary_coordinates" in ds: +# del ds["y_geostationary_coordinates"] +# # Need to remove these encodings to avoid chunking +# del ds.time.encoding["chunks"] +# del ds.time.encoding["preferred_chunks"] +# ds.to_zarr(f"{output_name.split('.zarr')[0]}_coord.zarr", consolidated=True, mode="w") +# # Remove current time ones +# shutil.rmtree(f"{output_name}/time/") +# # Add new time ones +# shutil.copytree(f"{output_name.split('.zarr')[0]}_coord.zarr/time", f"{output_name}/time") +# +# # Now replace the part of the .zmetadata with the part of the .zmetadata from the new coord one +# with open(f"{output_name}/.zmetadata") as f: +# data = json.load(f) +# with open(f"{output_name.split('.zarr')[0]}_coord.zarr/.zmetadata") as f2: +# coord_data = json.load(f2) +# data["metadata"]["time/.zarray"] = coord_data["metadata"]["time/.zarray"] +# with open(f"{output_name}/.zmetadata", "w") as f: +# json.dump(data, f) +# zarr.consolidate_metadata(output_name) parser = argparse.ArgumentParser( @@ -601,6 +795,12 @@ def _rewrite_zarr_times(output_name: str) -> None: type=str, choices=list(CONFIGS.keys()), ) +parser.add_argument( + "--hrv", + help="Download HRV instead of non-HRV data", + action="store_true", + default=False, +) 
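The new flag is read in `run()` as `args.hrv` to choose between the HRV and non-HRV channel sets. A hypothetical invocation is sketched below; the `--month` spelling is assumed from `run()`'s use of `args.month` (that argument is defined outside this hunk), the path is a placeholder, and `iodc` is one of the existing `CONFIGS` keys.

    # Sketch: build the HRV archive for January 2024; omit --hrv to process the non-HRV channels.
    args = parser.parse_args(["iodc", "--month", "2024-01", "--path", "/mnt/sat-data", "--hrv"])
    run(args)
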
parser.add_argument( "--path", "-p", help="Path to store the downloaded data", @@ -630,7 +830,10 @@ def _rewrite_zarr_times(output_name: str) -> None: def check_data_quality(ds: xr.Dataset) -> None: - """Check the quality of the data in the given dataset.""" + """Check the quality of the data in the given dataset. + + Looks for the number of NaNs in the data over important regions. + """ def _calc_null_percentage(data: np.ndarray): nulls = np.isnan(data) @@ -638,12 +841,16 @@ def _calc_null_percentage(data: np.ndarray): result = xr.apply_ufunc( _calc_null_percentage, - ds.data_vars["data"], + ds.data_vars["data"].sel( + x_geostationary=slice(-480_064.6, -996_133.85), + y_geostationary=slice(4_512_606.3, 5_058_679.8), + ), input_core_dims=[["x_geostationary", "y_geostationary"]], vectorize=True, + dask="parallelized", ) - num_images_failing_nulls_threshold = (result > 0.05).sum().item() + num_images_failing_nulls_threshold = (result > 0.05).sum().values num_images = result.size log.info( f"{num_images_failing_nulls_threshold}/{num_images} " @@ -656,77 +863,94 @@ def run(args: argparse.Namespace) -> None: prog_start = dt.datetime.now(tz=dt.UTC) log.info(f"{prog_start!s}: Running with args: {args}") - # Get running folder from args - folder: pathlib.Path = args.path / args.sat - - # Get config for desired satellite + # Get values from args + folder: pathlib.Path = args.path sat_config = CONFIGS[args.sat] - - # Get start and end times for run start: dt.datetime = dt.datetime.strptime(args.month, "%Y-%m") - end: dt.datetime = \ - start.replace(month=start.month + 1) if start.month < 12 \ - else start.replace(year=start.year + 1, month=1) \ - - dt.timedelta(days=1) - scan_times: list[pd.Timestamp] = pd.date_range( + end: dt.datetime = (start + pd.DateOffset(months=1, minutes=-1)).to_pydatetime() + dstype: str = "hrv" if args.hrv else "nonhrv" + + product_iter, total = get_products_iterator( + sat_config=sat_config, start=start, end=end, - freq=sat_config.cadence, - inclusive="left", - ).tolist() - - # Estimate average runtime - secs_per_scan: int = 90 - expected_runtime = pd.Timedelta(secs_per_scan * len(scan_times), "seconds") - log.info( - f"Downloading {len(scan_times)} scans ({start} - {end}). 
" - f"Expected runtime: {expected_runtime!s}" + token=_gen_token(), ) + # Use existing zarr store if it exists + ds: xr.Dataset | None = None + zarr_path = folder / start.strftime(sat_config.zarr_fmtstr[dstype]) + if zarr_path.exists(): + log.info(f"Using existing zarr store at '{zarr_path}'") + ds = xr.open_zarr(zarr_path, consolidated=True) + + # Iterate through all products in search + for product in tqdm(product_iter, total=total, miniters=50): + + # Skip products already present in store + if ds is not None: + product_time: dt.datetime = product.sensing_start.replace(second=0, microsecond=0) + if np.datetime64(product_time, "ns") in ds.coords["time"].values: + log.debug( + f"Skipping entry '{product!s}' as '{product_time}' already in store" + ) + continue + + # For non-existing products, download and process + nat_filepath = download_nat( + product=product, + folder=folder / args.sat, + ) + if nat_filepath is None: + raise OSError(f"Failed to download product '{product}'") + da = process_nat(sat_config, nat_filepath, dstype) + write_to_zarr(da=da, zarr_path=zarr_path) + + runtime = dt.datetime.now(tz=dt.UTC) - prog_start + log.info(f"Completed archive for args: {args} in {runtime!s}.") + # Download data # We only parallelize if we have a number of files larger than the cpu count - token = _gen_token() - raw_paths: list[pathlib.Path] = [] - if len(scan_times) > cpu_count(): - log.debug(f"Concurrency: {cpu_count()}") - pool = Pool(max(cpu_count(), 10)) # EUMDAC only allows for 10 concurrent requests - results: list[list[pathlib.Path]] = pool.starmap( - download_scans, - [(sat_config, folder, scan_time, token) for scan_time in scan_times], - ) - pool.close() - pool.join() - raw_paths.extend(list(itertools.chain(*results))) - else: - for scan_time in scan_times: - result: list[pathlib.Path] = download_scans(sat_config, folder, scan_time, token) - if len(result) > 0: - raw_paths.extend(result) - - log.info(f"Downloaded {len(raw_paths)} files.") - log.info("Converting raw data to HRV and non-HRV Zarr Stores.") + # token = _gen_token() + # raw_paths: list[pathlib.Path] = [] + # if len(scan_times) > cpu_count(): + # log.debug(f"Concurrency: {cpu_count()}") + # pool = Pool(max(cpu_count(), 10)) # EUMDAC only allows for 10 concurrent requests + # results: list[list[pathlib.Path]] = pool.starmap( + # download_scans, + # [(sat_config, folder, scan_time, token) for scan_time in scan_times], + # ) + # pool.close() + # pool.join() + # raw_paths.extend(list(itertools.chain(*results))) + # else: + # for scan_time in scan_times: + # result: list[pathlib.Path] = download_scans(sat_config, folder, scan_time, token) + # if len(result) > 0: + # raw_paths.extend(result) + + # log.info(f"Downloaded {len(raw_paths)} files.") + # log.info("Converting raw data to HRV and non-HRV Zarr Stores.") # Process the HRV and non-HRV data concurrently if possible - completed_types: list[str] = [] - for t in ["hrv", "nonhrv"]: - log.info(f"Processing {t} data.") - completed_type = process_scans(sat_config, folder, start, end, t) - completed_types.append(completed_type) - for completed_type in completed_types: - log.info(f"Processed {completed_type} data.") - - # Calculate the new average time per timestamp - runtime: dt.timedelta = dt.datetime.now(tz=dt.UTC) - prog_start - new_average_secs_per_scan: int = int( - (secs_per_scan + (runtime.total_seconds() / len(scan_times))) / 2, - ) - log.info(f"Completed archive for args: {args}. 
({new_average_secs_per_scan} seconds per scan).")
 
     if args.validate:
-        for t in completed_types:
-            zarr_path: pathlib.Path = folder.parent / start.strftime(sat_config.zarr_fmtstr[t])
-            ds = xr.open_zarr(zarr_path, consolidated=True)
-            check_data_quality(ds)
+        ds = xr.open_zarr(zarr_path, consolidated=True)
+        check_data_quality(ds)
 
     # Delete raw files, if desired
     if args.delete_raw:
@@ -739,3 +963,4 @@ def run(args: argparse.Namespace) -> None:
     # Parse running args
     args = parser.parse_args()
     run(args)
+
diff --git a/containers/sat/test_download_process_sat.py b/containers/sat/test_download_process_sat.py
index 9ac16ad..fde8578 100644
--- a/containers/sat/test_download_process_sat.py
+++ b/containers/sat/test_download_process_sat.py
@@ -25,17 +25,6 @@ class TestDownloadProcessSat(unittest.TestCase):
     def setUpClass(cls) -> None:
         TIMESTAMP = pd.Timestamp("2024-01-01T00:00:00Z")
 
-        token = dps._gen_token()
-
-        for t in [TIMESTAMP + pd.Timedelta(t) for t in ["0m", "15m", "30m", "45m"]]:
-            paths = dps.download_scans(
-                sat_config=dps.CONFIGS["iodc"],
-                folder=pathlib.Path("/tmp/test_sat_data"),
-                scan_time=t,
-                token=token,
-            )
-            cls.paths = paths
-
         attrs: dict = {
             "end_time": TIMESTAMP + pd.Timedelta("15m"),
             "modifiers": (),
@@ -76,8 +65,19 @@ def setUpClass(cls) -> None:
             ),
         }
 
-    def test_download_scans(self) -> None:
-        self.assertGreater(len(self.paths), 0)
+    def test_get_products_iterator(self) -> None:
+        """Test that the iterator returns the correct number of products."""
+        token = dps._gen_token()
+        for sat, config in dps.CONFIGS.items():
+            with self.subTest(sat=sat):
+                products_iter, total = dps.get_products_iterator(
+                    sat_config=config,
+                    start=pd.Timestamp("2024-01-01").to_pydatetime(),
+                    end=(pd.Timestamp("2024-01-01") + pd.Timedelta(config.cadence)).to_pydatetime(),
+                    token=token,
+                )
+                self.assertEqual(total, 1)
+
 
     def test_convert_scene_to_dataarray(self) -> None:
         scene = Scene(filenames={"seviri_l1b_native": [self.paths[0].as_posix()]})
@@ -114,6 +114,16 @@ def test_open_and_scale_data(self) -> None:
         self.assertDictEqual(dict(ds.sizes), dict(ds2.sizes))
         self.assertNotEqual(dict(ds.attrs), {})
 
+    def test_process_nat(self) -> None:
+        out = dps.process_nat(
+            dps.CONFIGS["iodc"],
+            pathlib.Path("/tmp/test_sat_data"),
+            "nonhrv",
+        )
+
+        self.assertIsNotNone(out)
+
 
     def test_process_scans(self) -> None:
         out: str = dps.process_scans(
@@ -125,3 +135,5 @@ def test_process_scans(self) -> None:
         self.assertTrue(False)
 
 
+if __name__ == "__main__":
+    unittest.main()
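
To show how the refactored pieces compose, the sketch below mirrors the body of `run()` in `download_process_sat.py` using the functions this patch introduces. It is a minimal illustration under assumptions: the `iodc` key is taken from `CONFIGS`, the paths and the one-hour window are placeholders, and `_gen_token()` is assumed to find EUMETSAT credentials the same way it does when the script is run normally.

    import datetime as dt
    import pathlib

    import download_process_sat as dps

    config = dps.CONFIGS["iodc"]
    start = dt.datetime(2024, 1, 1)
    end = dt.datetime(2024, 1, 1, 1)  # one hour of scans, for illustration
    token = dps._gen_token()

    # Search for products in the window, then handle them one at a time
    products, total = dps.get_products_iterator(sat_config=config, start=start, end=end, token=token)
    print(f"Processing {total} products")

    zarr_path = pathlib.Path("/tmp/sat") / start.strftime(config.zarr_fmtstr["nonhrv"])
    for product in products:
        # Download the raw .nat file, skipping products that fail to download
        nat_path = dps.download_nat(product=product, folder=pathlib.Path("/tmp/sat/iodc"))
        if nat_path is None:
            continue
        # Convert to a rescaled DataArray and append it to the monthly Zarr store
        da = dps.process_nat(config, nat_path, "nonhrv")
        dps.write_to_zarr(da=da, zarr_path=zarr_path)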