Skip to content

Commit

Permalink
Merge pull request #3 from openclimatefix/dev
Browse files Browse the repository at this point in the history
Add dataloader
  • Loading branch information
dfulu authored Jul 3, 2024
2 parents 99166f8 + efb9a40 commit fc64df9
Show file tree
Hide file tree
Showing 8 changed files with 746 additions and 3 deletions.
18 changes: 15 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,18 @@ conda create -n sat_pred python=3.10
conda activate sat_pred
```

Install the dependencies
Clone this repo

```
pip install -r requirements.txt
git clone https://github.com/openclimatefix/sat_pred.git
```

Install the package and its dependencies

```
cd sat_pred
pip install -e .
```

Create a empty directory to save the satellite data to
Expand All @@ -39,4 +47,8 @@ The above script downloads all the satellite imagery from June 2020. The input a
Note that the above script creates a satellite dataset which is 21GB. On my machine it used about
12GB of RAM at its peak and took around 30 minutes to run.

See the notebook `plot_satellite_image_example.ipynb` for loading and plotting example

See the notebook `01-plot_satellite_image_example.ipynb` for loading and plotting example.


See the notebook `02-data_loader_demo.ipynb` for getting started with the dataloader.
407 changes: 407 additions & 0 deletions notebooks/02-data_loader_demo.ipynb

Large diffs are not rendered by default.

3 changes: 3 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
ocf_blosc2
ocf_datapipes
lightning
torch
torchvision
numpy
pandas
xarray
Expand Down
213 changes: 213 additions & 0 deletions sat_pred/dataloader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,213 @@
"""Dataset and DataModule for past and future satellite data"""

from typing import Union
from datetime import timedelta

import numpy as np
import pandas as pd
import xarray as xr

from torch.utils.data import Dataset, DataLoader
from lightning.pytorch import LightningDataModule

from ocf_datapipes.load.satellite import _get_single_sat_data
from ocf_datapipes.select.find_contiguous_t0_time_periods import (
find_contiguous_time_periods, find_contiguous_t0_time_periods
)


def minutes(m):
"""Timedelta of a number of minutes"""
return timedelta(minutes=m)


def load_satellite_zarrs(zarr_path):
"""Load the satellite data"""
if isinstance(zarr_path, (list, tuple)):
ds = xr.combine_nested(
[_get_single_sat_data(path) for path in zarr_path],
concat_dim="time",
combine_attrs="override",
join="override",
)
else:
ds = _get_single_sat_data(zarr_path)

return ds


def find_valid_t0_times(ds, history_mins, forecast_mins, sample_freq_mins):
"""Constuct an array of all t0 times which are valid considering the gaps in the sat data"""

# Find periods where we have contiguous time steps
contiguous_time_periods = find_contiguous_time_periods(
datetimes=pd.DatetimeIndex(ds.time),
min_seq_length=int((history_mins + forecast_mins) / sample_freq_mins) + 1,
max_gap_duration=minutes(sample_freq_mins),
)

# Find periods of valid init-times
contiguous_t0_periods = find_contiguous_t0_time_periods(
contiguous_time_periods=contiguous_time_periods,
history_duration=minutes(history_mins),
forecast_duration=minutes(forecast_mins),
)

valid_t0_times = []
for _, row in contiguous_t0_periods.iterrows():
start_dt = row["start_dt"]
end_dt = row["end_dt"]
valid_t0_times.append(pd.date_range(row["start_dt"], row["end_dt"], freq="5min"))

valid_t0_times = pd.to_datetime(np.concatenate(valid_t0_times))

return valid_t0_times


class SatelliteDataset(Dataset):
def __init__(
self,
zarr_path: Union[list, str],
start_time: str,
end_time: str,
history_mins: int,
forecast_mins: int,
sample_freq_mins: int,
):
"""A torch Dataset for loading past and future satellite data
Args:
zarr_path: Path the satellite data. Can be a string or list
start_time: The satellite data is filtered to exclude timestamps before this
end_time: The satellite data is filtered to exclude timestamps after this
history_mins: How many minutes of history will be used as input features
forecast_mins: How many minutes of future will be used as target features
sample_freq_mins: The sample frequency to use for the satellite data
"""

# Load the sat zarr file or list of files and slice the data to the given period
self.ds = load_satellite_zarrs(zarr_path).sel(time=slice(start_time, end_time))

# Convert the satellite data to the given time frequency by selection
mask = np.mod(self.ds.time.dt.minute, sample_freq_mins)==0
self.ds = self.ds.sel(time=mask)

# Find the valid t0 times for the available data. This avoids trying to take samples where
# there would be a missing timestamp in the sat data required for the sample
self.valid_t0_times = find_valid_t0_times(self.ds, history_mins, forecast_mins, sample_freq_mins)

self.history_mins = history_mins
self.forecast_mins = forecast_mins
self.sample_freq_mins = sample_freq_mins


def __len__(self):
return len(self.valid_t0_times)


def __getitem__(self, idx):
t0 = self.valid_t0_times[idx]

ds_sel = self.ds.sel(
time=slice(
t0-minutes(self.history_mins),
t0+minutes(self.forecast_mins)
)
)

# Load the data eagerly so that the same chunks aren't loaded multiple times after we split
# further
ds_sel = ds_sel.compute(scheduler="single-threaded")

# Reshape to (channel, time, height, width)
ds_sel = ds_sel.transpose("variable", "time", "y_geostationary", "x_geostationary")

ds_input = ds_sel.sel(time=slice(None, t0))
ds_target = ds_sel.sel(time=slice(t0+minutes(self.sample_freq_mins), None))

# Convert to arrays
X = ds_input.data.values
y = ds_target.data.values

return X, y


class SatelliteDataModule(LightningDataModule):
"""A lightning DataModule for loading past and future satellite data"""

def __init__(
self,
zarr_path: Union[list, str],
history_mins: int,
forecast_mins: int,
sample_freq_mins: int,
batch_size=16,
num_workers=0,
prefetch_factor=None,
train_period=[None, None],
val_period=[None, None],
test_period=[None, None],
):
"""A lightning DataModule for loading past and future satellite data
Args:
zarr_path: Path the satellite data. Can be a string or list
history_mins: How many minutes of history will be used as input features
forecast_mins: How many minutes of future will be used as target features
sample_freq_mins: The sample frequency to use for the satellite data
batch_size: Batch size.
num_workers: Number of workers to use in multiprocess batch loading.
prefetch_factor: Number of data will be prefetched at the end of each worker process.
train_period: Date range filter for train dataloader.
val_period: Date range filter for val dataloader.
test_period: Date range filter for test dataloader.
"""
super().__init__()

self.zarr_path = zarr_path
self.history_mins = history_mins
self.forecast_mins = forecast_mins
self.sample_freq_mins = sample_freq_mins
self.train_period = train_period
self.val_period = val_period
self.test_period = test_period

self._common_dataloader_kwargs = dict(
batch_size=batch_size,
sampler=None,
batch_sampler=None,
num_workers=num_workers,
pin_memory=False,
drop_last=False,
timeout=0,
worker_init_fn=None,
prefetch_factor=prefetch_factor,
persistent_workers=False,
)

def _make_dataset(self, start_date, end_date):
dataset = SatelliteDataset(
self.zarr_path,
start_date,
end_date,
self.history_mins,
self.forecast_mins,
self.sample_freq_mins,
)
return dataset

def train_dataloader(self):
"""Construct train dataloader"""
dataset = self._make_dataset(*self.train_period)
return DataLoader(dataset, shuffle=True, **self._common_dataloader_kwargs)

def val_dataloader(self):
"""Construct val dataloader"""

dataset = self._make_dataset(*self.val_period)
return DataLoader(dataset, shuffle=False, **self._common_dataloader_kwargs)

def test_dataloader(self):
"""Construct test dataloader"""
dataset = self._make_dataset(*self.test_period)
return DataLoader(dataset, shuffle=False, **self._common_dataloader_kwargs)
31 changes: 31 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import os
import tempfile

import pytest
import numpy as np
import xarray as xr

xr.set_options(keep_attrs=True)


@pytest.fixture()
def sat_zarr_path():

# Load dataset which only contains coordinates, but no data
ds = xr.load_dataset(
f"{os.path.dirname(os.path.abspath(__file__))}/test_data/non_hrv_shell.netcdf"
)

# Add data to dataset
ds["data"] = xr.DataArray(
np.zeros([len(ds[c]) for c in ds.coords]),
coords=ds.coords,
)

with tempfile.TemporaryDirectory() as temp_dir:

# Save temporarily as a zarr
zarr_path = f"{temp_dir}/test_sat.zarr"
ds.to_zarr(zarr_path)

yield zarr_path
Binary file added tests/test_data/non_hrv_shell.netcdf
Binary file not shown.
77 changes: 77 additions & 0 deletions tests/test_dataloader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
from sat_pred.dataloader import (
load_satellite_zarrs,
find_valid_t0_times,
SatelliteDataset,
SatelliteDataModule,
)


def test_load_satellite_zarrs(sat_zarr_path):

# Check can load with string and list of string(s)
ds = load_satellite_zarrs(sat_zarr_path)
ds = load_satellite_zarrs([sat_zarr_path])

# Dataset is a full 24 hours of 5 minutely data -> 24hours * (60/5) = 288
assert len(ds.time)==288


def test_find_valid_t0_times(sat_zarr_path):
ds = load_satellite_zarrs(sat_zarr_path)

t0_times = find_valid_t0_times(
ds,
history_mins=60,
forecast_mins=120,
sample_freq_mins=5,
)

# original timesteps 288
# forecast length buffer - (120 / 5)
# history length buffer - (60 / 5)
# ------------
# Total 252

assert len(t0_times)==252


def test_satellite_dataset(sat_zarr_path):
dataset = SatelliteDataset(
zarr_path=sat_zarr_path,
start_time=None,
end_time=None,
history_mins=60,
forecast_mins=120,
sample_freq_mins=5,
)

assert len(dataset)==252

X, y = dataset[0]

# 11 channels
# 372 y-dim steps
# 614 x-dim steps
# (60 / 5) + 1 = 13 history steps
# (120 / 5) = 24 forecast steps
assert X.shape==(11, 13, 372, 614)
assert y.shape==(11, 24, 372, 614)


def test_satellite_datamodule(sat_zarr_path):
datamodule = SatelliteDataModule(
zarr_path=sat_zarr_path,
history_mins=60,
forecast_mins=120,
sample_freq_mins=5,
batch_size=2,
num_workers=2,
prefetch_factor=None,
)

dl = datamodule.train_dataloader()

X, y = next(iter(dl))

assert X.shape==(2, 11, 13, 372, 614)
assert y.shape==(2, 11, 24, 372, 614)

0 comments on commit fc64df9

Please sign in to comment.