-
-
Notifications
You must be signed in to change notification settings - Fork 47
Refactoring all providers into a single script #413
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
This file was deleted.
This file was deleted.
This file was deleted.
This file was deleted.
This file was deleted.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,108 @@ | ||
| """NWP provider loaders. | ||
|
|
||
| All providers follow the same pattern: | ||
| open zarr -> standardise dim/coord names -> shared post-processing. | ||
|
|
||
| `_open_regular_grid_nwp` is the shared tail. Per-provider functions only | ||
| handle the open + renaming step that differs between data sources. | ||
| """ | ||
|
|
||
| import logging | ||
|
|
||
| import xarray as xr | ||
|
|
||
| from ocf_data_sampler.load.nwp.providers.utils import open_zarr_paths | ||
| from ocf_data_sampler.load.utils import ( | ||
| check_time_unique_increasing, | ||
| get_xr_data_array_from_xr_dataset, | ||
| make_spatial_coords_increasing, | ||
| ) | ||
|
|
||
| _log = logging.getLogger(__name__) | ||
|
|
||
|
|
||
def _open_regular_grid_nwp(
    ds: xr.Dataset | xr.DataArray,
    x_coord: str,
    y_coord: str,
) -> xr.DataArray:
    """Run the post-processing shared by every regular-grid NWP provider.

    The input must already use the dim/coord names ``init_time_utc``, ``step``
    and ``channel``, with its two spatial dims named ``x_coord`` and ``y_coord``.

    Args:
        ds: Opened provider data, either a Dataset or a DataArray.
        x_coord: Name of the x spatial dim/coord.
        y_coord: Name of the y spatial dim/coord.

    Returns:
        The data transposed to (init_time_utc, step, channel, x_coord, y_coord).
        A Dataset is converted with ``get_xr_data_array_from_xr_dataset``;
        a DataArray is returned directly.
    """
    dim_order = ("init_time_utc", "step", "channel", x_coord, y_coord)

    # Validate the init times before any reordering is done.
    check_time_unique_increasing(ds.init_time_utc)

    gridded = make_spatial_coords_increasing(ds, x_coord=x_coord, y_coord=y_coord)
    gridded = gridded.transpose(*dim_order)

    # A DataArray is already in its final form; only a Dataset needs converting.
    if not isinstance(gridded, xr.Dataset):
        return gridded
    return get_xr_data_array_from_xr_dataset(gridded)
|
|
||
|
|
||
def open_ifs(zarr_path: str | list[str]) -> xr.DataArray:
    """Opens ECMWF IFS / MetOffice Global NWP data.

    Args:
        zarr_path: Path, or list of paths, of the zarr store(s) to open.

    Returns:
        The opened data after the shared regular-grid post-processing, using
        ``longitude``/``latitude`` as the spatial coords.
    """
    # NOTE(review): a reviewer flagged that open_zarr_paths has to be told what
    # to use when opening (the comment is truncated; presumably the time dim).
    # No time_dim is passed here, so the helper's default applies - confirm it
    # is right for both legacy and current stores.
    ds = open_zarr_paths(zarr_path, backend="tensorstore")

    # LEGACY SUPPORT - older zarrs use "init_time"/"variable" dim names.
    # Rename only the legacy names that are actually present: `rename` raises
    # on a missing key, which would break stores already using the new names.
    legacy_names = {"init_time": "init_time_utc", "variable": "channel"}
    present = {
        old: new
        for old, new in legacy_names.items()
        if old in ds.dims or old in ds.coords
    }
    if present:
        ds = ds.rename(present)

    return _open_regular_grid_nwp(ds, x_coord="longitude", y_coord="latitude")
|
|
||
|
|
||
def open_gdm(zarr_path: str | list[str]) -> xr.DataArray:
    """Opens GDM (e.g. GenCast) NWP data.

    Args:
        zarr_path: Path, or list of paths, of the zarr store(s) to open.

    Returns:
        The opened data after the shared regular-grid post-processing, using
        ``longitude``/``latitude`` as the spatial coords.
    """
    # No renaming step here: the store is opened with "init_time_utc" as its
    # time dim and handed straight to the shared post-processing.
    gdm_ds = open_zarr_paths(
        zarr_path,
        backend="tensorstore",
        time_dim="init_time_utc",
    )
    return _open_regular_grid_nwp(
        gdm_ds,
        x_coord="longitude",
        y_coord="latitude",
    )
|
|
||
|
|
||
def open_gfs(zarr_path: str | list[str], public: bool = False) -> xr.DataArray:
    """Opens GFS NWP data.

    Args:
        zarr_path: Path, or list of paths, of the zarr store(s) to open.
        public: Forwarded unchanged to ``open_zarr_paths``.

    Returns:
        The opened data after the shared regular-grid post-processing, using
        ``longitude``/``latitude`` as the spatial coords.
    """
    _log.info("Loading NWP GFS data")

    # Stack the data variables along a new "channel" dim. Chaining the calls
    # means the intermediate Dataset is never bound to a name.
    nwp = open_zarr_paths(
        zarr_path,
        time_dim="init_time_utc",
        public=public,
        backend="dask",
    ).to_array(dim="channel")

    return _open_regular_grid_nwp(nwp, x_coord="longitude", y_coord="latitude")
|
|
||
|
|
||
def open_icon_eu(zarr_path: str | list[str]) -> xr.DataArray:
    """Opens DWD ICON-EU data.

    ICON-EU is expected to be on a regular lat/lon grid with a 'channel' dim.
    Only the first 78 (one-hour) steps are used; the rest are 3-hour steps.

    Args:
        zarr_path: Path, or list of paths, of the zarr store(s) to open.

    Returns:
        The ``icon_eu_data`` variable, cut to its first 78 steps, after the
        shared regular-grid post-processing.

    Raises:
        ValueError: If the opened store has no ``icon_eu_data`` data variable.
    """
    data_var = "icon_eu_data"
    n_hourly_steps = 78

    ds = open_zarr_paths(zarr_path, time_dim="init_time_utc", backend="dask")

    if data_var in ds.data_vars:
        hourly = ds[data_var].isel(step=slice(0, n_hourly_steps))
        return _open_regular_grid_nwp(hourly, x_coord="longitude", y_coord="latitude")

    raise ValueError("Could not find 'icon_eu_data' DataArray in the ICON-EU Zarr file.")
|
|
||
|
|
||
def open_ukv(zarr_path: str | list[str]) -> xr.DataArray:
    """Opens UKV NWP data (OSGB grid).

    Args:
        zarr_path: Path, or list of paths, of the zarr store(s) to open.

    Returns:
        The opened data after the shared regular-grid post-processing, using
        ``x_osgb``/``y_osgb`` as the spatial coords.
    """
    legacy_to_current = (
        ("init_time", "init_time_utc"),
        ("variable", "channel"),
        ("x", "x_osgb"),
        ("y", "y_osgb"),
    )

    ds = open_zarr_paths(zarr_path, backend="tensorstore")

    # New UKV data already uses the target names, so only legacy names that
    # are actually present among the coords are renamed.
    renames = {}
    for old_name, new_name in legacy_to_current:
        if old_name in ds.coords:
            renames[old_name] = new_name
    ds = ds.rename(renames)

    return _open_regular_grid_nwp(ds, x_coord="x_osgb", y_coord="y_osgb")
|
|
||
|
|
||
def open_cloudcasting(zarr_path: str | list[str]) -> xr.DataArray:
    """Opens OCF cloudcasting satellite-prediction data (geostationary grid).

    Args:
        zarr_path: Path, or list of paths, of the zarr store(s) to open.

    Returns:
        The opened data after the shared regular-grid post-processing, using
        ``x_geostationary``/``y_geostationary`` as the spatial coords.

    References:
        [1] https://www.openclimatefix.org/projects/cloud-forecasting
        [2] https://github.com/ClimeTrend/cloudcasting
        [3] https://github.com/openclimatefix/sat_pred
    """
    # The store's "init_time"/"variable" names are mapped onto the names the
    # shared post-processing expects.
    renames = {"init_time": "init_time_utc", "variable": "channel"}
    renamed = open_zarr_paths(zarr_path, backend="tensorstore").rename(renames)

    return _open_regular_grid_nwp(
        renamed,
        x_coord="x_geostationary",
        y_coord="y_geostationary",
    )
This file was deleted.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Just a few suggestions on wording so it doesn't get confusing: "shape" often refers to the shape of the data, and "normalise" to the normalisation of data, which we do elsewhere in this repo, so I'd avoid those words here.