diff --git a/ocf_data_sampler/load/nwp/nwp.py b/ocf_data_sampler/load/nwp/nwp.py index 24fca54d..12442112 100755 --- a/ocf_data_sampler/load/nwp/nwp.py +++ b/ocf_data_sampler/load/nwp/nwp.py @@ -5,12 +5,14 @@ import numpy as np import xarray as xr -from ocf_data_sampler.load.nwp.providers.cloudcasting import open_cloudcasting -from ocf_data_sampler.load.nwp.providers.ecmwf import open_ifs -from ocf_data_sampler.load.nwp.providers.gdm import open_gdm -from ocf_data_sampler.load.nwp.providers.gfs import open_gfs -from ocf_data_sampler.load.nwp.providers.icon import open_icon_eu -from ocf_data_sampler.load.nwp.providers.ukv import open_ukv +from ocf_data_sampler.load.nwp.providers.loaders import ( + open_cloudcasting, + open_gdm, + open_gfs, + open_icon_eu, + open_ifs, + open_ukv, +) _OPEN_NWP_FUNCTIONS: dict[str, Callable[..., xr.DataArray]] = { "ukv": open_ukv, diff --git a/ocf_data_sampler/load/nwp/providers/cloudcasting.py b/ocf_data_sampler/load/nwp/providers/cloudcasting.py deleted file mode 100644 index ac0f41c5..00000000 --- a/ocf_data_sampler/load/nwp/providers/cloudcasting.py +++ /dev/null @@ -1,49 +0,0 @@ -"""Cloudcasting provider loader.""" - -import xarray as xr - -from ocf_data_sampler.load.nwp.providers.utils import open_zarr_paths -from ocf_data_sampler.load.utils import ( - check_time_unique_increasing, - get_xr_data_array_from_xr_dataset, - make_spatial_coords_increasing, -) - - -def open_cloudcasting(zarr_path: str | list[str]) -> xr.DataArray: - """Opens the satellite predictions from cloudcasting. - - Cloudcasting is a OCF forecast product. We forecast future satellite images from recent - satellite images. More information can be found in the references below. - - Args: - zarr_path: Path to the zarr(s) to open - - Returns: - Xarray DataArray of the cloudcasting data - - References: - [1] https://www.openclimatefix.org/projects/cloud-forecasting - [2] https://github.com/ClimeTrend/cloudcasting - [3] https://github.com/openclimatefix/sat_pred - """ - # Open the data - ds = open_zarr_paths(zarr_path, backend="tensorstore") - - # Rename - ds = ds.rename( - { - "init_time": "init_time_utc", - "variable": "channel", - }, - ) - - # Check the timestamps are unique and increasing - check_time_unique_increasing(ds.init_time_utc) - - # Make sure the spatial coords are in increasing order - ds = make_spatial_coords_increasing(ds, x_coord="x_geostationary", y_coord="y_geostationary") - - ds = ds.transpose("init_time_utc", "step", "channel", "x_geostationary", "y_geostationary") - - return get_xr_data_array_from_xr_dataset(ds) diff --git a/ocf_data_sampler/load/nwp/providers/ecmwf.py b/ocf_data_sampler/load/nwp/providers/ecmwf.py deleted file mode 100755 index 2a7cc095..00000000 --- a/ocf_data_sampler/load/nwp/providers/ecmwf.py +++ /dev/null @@ -1,34 +0,0 @@ -"""ECMWF provider loaders.""" - -import xarray as xr - -from ocf_data_sampler.load.nwp.providers.utils import open_zarr_paths -from ocf_data_sampler.load.utils import ( - check_time_unique_increasing, - get_xr_data_array_from_xr_dataset, - make_spatial_coords_increasing, -) - - -def open_ifs(zarr_path: str | list[str]) -> xr.DataArray: - """Opens the ECMWF IFS NWP data. - - Args: - zarr_path: Path to the zarr(s) to open - - Returns: - Xarray DataArray of the NWP data - """ - ds = open_zarr_paths(zarr_path, backend="tensorstore") - - # LEGACY SUPPORT - rename variable to channel if it exists - ds = ds.rename({"init_time": "init_time_utc", "variable": "channel"}) - - check_time_unique_increasing(ds.init_time_utc) - - ds = make_spatial_coords_increasing(ds, x_coord="longitude", y_coord="latitude") - - ds = ds.transpose("init_time_utc", "step", "channel", "longitude", "latitude") - - # TODO: should we control the dtype of the DataArray? - return get_xr_data_array_from_xr_dataset(ds) diff --git a/ocf_data_sampler/load/nwp/providers/gdm.py b/ocf_data_sampler/load/nwp/providers/gdm.py deleted file mode 100644 index 7e7e355e..00000000 --- a/ocf_data_sampler/load/nwp/providers/gdm.py +++ /dev/null @@ -1,31 +0,0 @@ -"""Loader for any GDM WP data.""" - -import xarray as xr - -from ocf_data_sampler.load.nwp.providers.utils import open_zarr_paths -from ocf_data_sampler.load.utils import ( - check_time_unique_increasing, - get_xr_data_array_from_xr_dataset, - make_spatial_coords_increasing, -) - - -def open_gdm(zarr_path: str | list[str]) -> xr.DataArray: - """Opens GDM NWP data. - - Args: - zarr_path: Path to the zarr(s) to open - - Returns: - Xarray DataArray of the NWP data - """ - ds = open_zarr_paths(zarr_path, backend="tensorstore", time_dim="init_time_utc") - - check_time_unique_increasing(ds.init_time_utc) - - ds = make_spatial_coords_increasing(ds, x_coord="longitude", y_coord="latitude") - - ds = ds.transpose("init_time_utc", "step", "channel", "longitude", "latitude") - - # TODO: should we control the dtype of the DataArray? - return get_xr_data_array_from_xr_dataset(ds) diff --git a/ocf_data_sampler/load/nwp/providers/gfs.py b/ocf_data_sampler/load/nwp/providers/gfs.py deleted file mode 100644 index 4b9d91b3..00000000 --- a/ocf_data_sampler/load/nwp/providers/gfs.py +++ /dev/null @@ -1,41 +0,0 @@ -"""Open GFS Forecast data.""" - -import logging - -import xarray as xr - -from ocf_data_sampler.load.nwp.providers.utils import open_zarr_paths -from ocf_data_sampler.load.utils import check_time_unique_increasing, make_spatial_coords_increasing - -_log = logging.getLogger(__name__) - - -def open_gfs(zarr_path: str | list[str], public: bool = False) -> xr.DataArray: - """Opens the GFS data. - - Args: - zarr_path: Path to the zarr(s) to open - public: Whether the data is public or private - - Returns: - Xarray DataArray of the NWP data - """ - _log.info("Loading NWP GFS data") - - # Open data - gfs: xr.Dataset = open_zarr_paths( - zarr_path, - time_dim="init_time_utc", - public=public, - backend="dask", - ) - nwp: xr.DataArray = gfs.to_array(dim="channel") - - del gfs - - check_time_unique_increasing(nwp.init_time_utc) - nwp = make_spatial_coords_increasing(nwp, x_coord="longitude", y_coord="latitude") - - nwp = nwp.transpose("init_time_utc", "step", "channel", "longitude", "latitude") - - return nwp diff --git a/ocf_data_sampler/load/nwp/providers/icon.py b/ocf_data_sampler/load/nwp/providers/icon.py deleted file mode 100644 index 606dfae7..00000000 --- a/ocf_data_sampler/load/nwp/providers/icon.py +++ /dev/null @@ -1,37 +0,0 @@ -"""DWD ICON Loading.""" - -import xarray as xr - -from ocf_data_sampler.load.nwp.providers.utils import open_zarr_paths -from ocf_data_sampler.load.utils import check_time_unique_increasing, make_spatial_coords_increasing - - -def open_icon_eu(zarr_path: str | list[str]) -> xr.DataArray: - """Opens the ICON data. - - ICON EU Data is now expected to be on a regular lat/lon grid, - with a 'channel' dimension directly available (as per the updated fixture). - The 'isobaricInhPa' dimension is expected to be already handled. - - Args: - zarr_path: Path to the zarr(s) to open - - Returns: - Xarray DataArray of the NWP data - """ - # Open and check initially - ds = open_zarr_paths(zarr_path, time_dim="init_time_utc", backend="dask") - - if "icon_eu_data" in ds.data_vars: - nwp = ds["icon_eu_data"] - else: - raise ValueError("Could not find 'icon_eu_data' DataArray in the ICON-EU Zarr file.") - - check_time_unique_increasing(nwp.init_time_utc) - - # 0-78 one hour steps, rest 3 hour steps - nwp = nwp.isel(step=slice(0, 78)) - nwp = nwp.transpose("init_time_utc", "step", "channel", "longitude", "latitude") - nwp = make_spatial_coords_increasing(nwp, x_coord="longitude", y_coord="latitude") - - return nwp diff --git a/ocf_data_sampler/load/nwp/providers/loaders.py b/ocf_data_sampler/load/nwp/providers/loaders.py new file mode 100644 index 00000000..abb2a937 --- /dev/null +++ b/ocf_data_sampler/load/nwp/providers/loaders.py @@ -0,0 +1,108 @@ +"""NWP provider loaders. + +All providers follow the same shape: + open zarr -> normalise dim/coord names -> shared post-processing. + +`_open_regular_grid_nwp` is the shared tail. Per-provider functions only +handle the open + renaming step that differs between data sources. +""" + +import logging + +import xarray as xr + +from ocf_data_sampler.load.nwp.providers.utils import open_zarr_paths +from ocf_data_sampler.load.utils import ( + check_time_unique_increasing, + get_xr_data_array_from_xr_dataset, + make_spatial_coords_increasing, +) + +_log = logging.getLogger(__name__) + + +def _open_regular_grid_nwp( + ds: xr.Dataset | xr.DataArray, + x_coord: str, + y_coord: str, +) -> xr.DataArray: + """Shared post-processing for any regular-grid NWP dataset. + + Expects dims/coords already normalised to: init_time_utc, step, channel, + plus the given x_coord/y_coord spatial dims. + """ + check_time_unique_increasing(ds.init_time_utc) + ds = make_spatial_coords_increasing(ds, x_coord=x_coord, y_coord=y_coord) + ds = ds.transpose("init_time_utc", "step", "channel", x_coord, y_coord) + + if isinstance(ds, xr.Dataset): + return get_xr_data_array_from_xr_dataset(ds) + return ds + + +def open_ifs(zarr_path: str | list[str]) -> xr.DataArray: + """Opens ECMWF IFS / MetOffice Global NWP data.""" + ds = open_zarr_paths(zarr_path, backend="tensorstore") + # LEGACY SUPPORT - older zarrs use "init_time"/"variable" dim names + ds = ds.rename({"init_time": "init_time_utc", "variable": "channel"}) + return _open_regular_grid_nwp(ds, x_coord="longitude", y_coord="latitude") + + +def open_gdm(zarr_path: str | list[str]) -> xr.DataArray: + """Opens GDM (e.g. GenCast) NWP data.""" + ds = open_zarr_paths(zarr_path, backend="tensorstore", time_dim="init_time_utc") + return _open_regular_grid_nwp(ds, x_coord="longitude", y_coord="latitude") + + +def open_gfs(zarr_path: str | list[str], public: bool = False) -> xr.DataArray: + """Opens GFS NWP data.""" + _log.info("Loading NWP GFS data") + ds = open_zarr_paths( + zarr_path, + time_dim="init_time_utc", + public=public, + backend="dask", + ) + nwp = ds.to_array(dim="channel") + del ds + return _open_regular_grid_nwp(nwp, x_coord="longitude", y_coord="latitude") + + +def open_icon_eu(zarr_path: str | list[str]) -> xr.DataArray: + """Opens DWD ICON-EU data. + + ICON-EU is expected to be on a regular lat/lon grid with a 'channel' dim. + Only the first 78 (one-hour) steps are used; the rest are 3-hour steps. + """ + ds = open_zarr_paths(zarr_path, time_dim="init_time_utc", backend="dask") + if "icon_eu_data" not in ds.data_vars: + raise ValueError("Could not find 'icon_eu_data' DataArray in the ICON-EU Zarr file.") + nwp = ds["icon_eu_data"].isel(step=slice(0, 78)) + return _open_regular_grid_nwp(nwp, x_coord="longitude", y_coord="latitude") + + +def open_ukv(zarr_path: str | list[str]) -> xr.DataArray: + """Opens UKV NWP data (OSGB grid).""" + ds = open_zarr_paths(zarr_path, backend="tensorstore") + # Only rename keys actually present - new UKV data already uses the target names + rename_map = { + "init_time": "init_time_utc", + "variable": "channel", + "x": "x_osgb", + "y": "y_osgb", + } + ds = ds.rename({k: v for k, v in rename_map.items() if k in ds.coords}) + return _open_regular_grid_nwp(ds, x_coord="x_osgb", y_coord="y_osgb") + + +def open_cloudcasting(zarr_path: str | list[str]) -> xr.DataArray: + """Opens OCF cloudcasting satellite-prediction data (geostationary grid). + + References: + [1] https://www.openclimatefix.org/projects/cloud-forecasting + [2] https://github.com/ClimeTrend/cloudcasting + [3] https://github.com/openclimatefix/sat_pred + """ + ds = open_zarr_paths(zarr_path, backend="tensorstore") + ds = ds.rename({"init_time": "init_time_utc", "variable": "channel"}) + return _open_regular_grid_nwp(ds, x_coord="x_geostationary", y_coord="y_geostationary") diff --git a/ocf_data_sampler/load/nwp/providers/ukv.py b/ocf_data_sampler/load/nwp/providers/ukv.py deleted file mode 100755 index 21fa1c2f..00000000 --- a/ocf_data_sampler/load/nwp/providers/ukv.py +++ /dev/null @@ -1,42 +0,0 @@ -"""UKV provider loaders.""" - -import xarray as xr - -from ocf_data_sampler.load.nwp.providers.utils import open_zarr_paths -from ocf_data_sampler.load.utils import ( - check_time_unique_increasing, - get_xr_data_array_from_xr_dataset, - make_spatial_coords_increasing, -) - - -def open_ukv(zarr_path: str | list[str]) -> xr.DataArray: - """Opens the NWP data. - - Args: - zarr_path: Path to the zarr(s) to open - - Returns: - Xarray DataArray of the NWP data - """ - ds = open_zarr_paths(zarr_path, backend="tensorstore") - - # Define the desired mapping - rename_map = { - "init_time": "init_time_utc", - "variable": "channel", - "x": "x_osgb", - "y": "y_osgb", - } - - # Only rename if the source key exists in the dataset's dimensions or coordinates - # This prevents KeyErrors when the new UKV data already has "x_osgb" and "y_osgb" - ds = ds.rename({k: v for k, v in rename_map.items() if k in ds.coords}) - check_time_unique_increasing(ds.init_time_utc) - - ds = make_spatial_coords_increasing(ds, x_coord="x_osgb", y_coord="y_osgb") - - ds = ds.transpose("init_time_utc", "step", "channel", "x_osgb", "y_osgb") - - # TODO: should we control the dtype of the DataArray? - return get_xr_data_array_from_xr_dataset(ds)