Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 17 additions & 3 deletions src/access_moppy/atmosphere.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,9 +222,23 @@ def select_and_process_variables(self):
result = evaluate_expression(calc, context)

# Check whether the time interval/frequency has changed (e.g. daily → monthly)
if "time" in result.dims and result.sizes["time"] != self.ds.sizes.get(
"time", result.sizes["time"]
):
result_has_time = "time" in result.dims
time_size_changed = result_has_time and result.sizes[
"time"
] != self.ds.sizes.get("time", result.sizes["time"])
# Even when sizes match, assignment can align by coordinate labels.
# If formula changes time labels (e.g. month-start -> month-midpoint),
# direct assignment would reindex to NaN and later become 1e20.
time_coord_changed = False
if result_has_time and not time_size_changed and "time" in self.ds.coords:
try:
time_coord_changed = not np.array_equal(
result["time"].values, self.ds["time"].values
)
except Exception:
time_coord_changed = True

if time_size_changed or time_coord_changed:
# If the temporal resolution changes, rebuild self.ds while preserving variables that are not time-dependent
time_indep = {
v: self.ds[v]
Expand Down
59 changes: 59 additions & 0 deletions src/access_moppy/derivations/calc_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,61 @@ def _monthly_midpoint_coord(time_da: xr.DataArray) -> xr.DataArray:
return time_da.copy(data=midpoints)


def _mask_missing_values_for_reduction(da: xr.DataArray) -> xr.DataArray:
    """Mask configured missing-value sentinels so reductions ignore them.

    Input files can carry numeric sentinels (for example ``1e20``) in data while
    storing marker values in attrs/encoding as ``_FillValue`` or ``missing_value``.
    If those sentinels are not masked, temporal maxima can collapse to the marker
    value. Every comparison below is expressed with xarray operators/ufuncs so
    the mask stays a ``DataArray`` and is evaluated lazily for Dask-backed input.

    Parameters
    ----------
    da
        Array whose ``attrs`` and ``encoding`` are searched for the keys
        ``missing_value`` and ``_FillValue`` (scalar or array-like values).

    Returns
    -------
    xr.DataArray
        ``da`` itself when no usable marker is configured; otherwise a copy
        with matching elements replaced by NaN and the original ``attrs`` and
        ``encoding`` carried over.
    """

    def _iter_markers(value):
        """Yield each candidate marker from a scalar or array-like value."""
        if value is None:
            return
        if np.isscalar(value):
            yield value
            return
        yield from np.ravel(value)

    markers = set()
    has_nan_marker = False
    for container in (da.attrs, da.encoding):
        for key in ("missing_value", "_FillValue"):
            for raw in _iter_markers(container.get(key)):
                try:
                    marker = float(raw)
                except (TypeError, ValueError):
                    # Non-numeric markers (e.g. strings) cannot match numeric data.
                    continue
                if np.isnan(marker):
                    has_nan_marker = True
                else:
                    markers.add(marker)

    mask = None
    if np.issubdtype(da.dtype, np.floating) and has_nan_marker:
        mask = np.isnan(da)

    min_atol = 1e-12
    float32_max = float(np.finfo(np.float32).max)
    for marker in markers:
        if np.isfinite(marker):
            # Match both exact values and float32-rounded encodings (e.g. 1e20)
            # by allowing one float32 ULP around the marker. Markers outside
            # the float32 range (or whose ULP is not finite) fall back to the
            # minimum tolerance instead of overflowing the float32 cast.
            ulp = (
                abs(float(np.spacing(np.float32(marker))))
                if abs(marker) <= float32_max
                else float("nan")
            )
            atol = ulp if np.isfinite(ulp) and ulp > min_atol else min_atol
            # Equivalent to np.isclose(da, marker, rtol=0.0, atol=atol) for a
            # finite marker, but np.isclose would coerce the DataArray to an
            # in-memory ndarray; plain arithmetic keeps it lazy and labelled.
            condition = abs(da - marker) <= atol
        else:
            # +/-inf markers can only be matched exactly.
            condition = da == marker
        mask = condition if mask is None else (mask | condition)

    if mask is None:
        return da

    masked = da.where(~mask)
    masked.attrs = da.attrs.copy()
    masked.encoding = da.encoding.copy()
    return masked


def calculate_monthly_minimum(
da: xr.DataArray, time_dim: str = "time", preserve_attrs: bool = True
) -> xr.DataArray:
Expand Down Expand Up @@ -262,6 +317,8 @@ def calculate_monthly_minimum(
_name = da.name or "__tmp"
da = xr.decode_cf(da.to_dataset(name=_name))[_name]

da = _mask_missing_values_for_reduction(da)

try:
monthly_min = da.resample({time_dim: "ME"}).min(keep_attrs=preserve_attrs)
# "ME" labels each bin at month-end; recentre to the cell midpoint
Expand Down Expand Up @@ -362,6 +419,8 @@ def calculate_monthly_maximum(
_name = da.name or "__tmp"
da = xr.decode_cf(da.to_dataset(name=_name))[_name]

da = _mask_missing_values_for_reduction(da)

try:
monthly_max = da.resample({time_dim: "ME"}).max(keep_attrs=preserve_attrs)
# "ME" labels each bin at month-end; recentre to the cell midpoint
Expand Down
75 changes: 75 additions & 0 deletions tests/unit/test_atmosphere.py
Original file line number Diff line number Diff line change
Expand Up @@ -1102,6 +1102,81 @@ def test_formula_same_time_length_uses_setitem(self):

assert cmoriser.ds["tasmax"].sizes["time"] == 12

@pytest.mark.unit
def test_formula_same_time_length_but_shifted_labels_rebuilds(self):
    """Shifted time labels must not be aligned away to NaN values.

    The formula result has the same number of time steps as the source
    dataset but carries mid-month labels instead of month-start labels.
    The processed dataset must expose the result's values and time labels.
    """
    monthly_time = pd.date_range("2020-01-01", periods=12, freq="MS")
    monthly_ds = xr.Dataset(
        {
            "tasmax": xr.DataArray(
                np.random.default_rng(5).normal(305, 5, 12),
                dims=["time"],
                coords={"time": monthly_time},
                attrs={"units": "K"},
            )
        }
    )
    monthly_ds["time"].attrs = {"units": "days since 1850-01-01"}

    # Offset every month-start label to the 16th. Building this with
    # pd.date_range("2020-01-16", freq="MS") would roll forward to the next
    # month start (2020-02-01, ...), leaving 11 of 12 labels overlapping the
    # source and never producing mid-month labels.
    shifted_time = monthly_time + pd.Timedelta(days=15)
    shifted_result = xr.DataArray(
        np.linspace(290.0, 301.0, 12),
        dims=["time"],
        coords={"time": shifted_time},
    )

    cmoriser = _make_cmoriser_for_formula(monthly_ds)

    with patch(
        "access_moppy.atmosphere.evaluate_expression",
        return_value=shifted_result,
    ):
        cmoriser.select_and_process_variables()

    np.testing.assert_allclose(cmoriser.ds["tasmax"].values, shifted_result.values)
    assert np.array_equal(cmoriser.ds["time"].values, shifted_result["time"].values)

@pytest.mark.unit
def test_formula_time_compare_exception_falls_back_to_rebuild(self):
    """A failing time-label comparison must still leave the formula result in place."""
    month_starts = pd.date_range("2020-01-01", periods=12, freq="MS")

    source_var = xr.DataArray(
        np.random.default_rng(6).normal(305, 5, 12),
        dims=["time"],
        coords={"time": month_starts},
        attrs={"units": "K"},
    )
    monthly_ds = xr.Dataset({"tasmax": source_var})
    monthly_ds["time"].attrs = {"units": "days since 1850-01-01"}

    # Same length and same labels as the source, so only the (patched)
    # label comparison can influence which path is taken.
    same_size_result = xr.DataArray(
        np.linspace(280.0, 291.0, 12),
        dims=["time"],
        coords={"time": month_starts},
    )

    cmoriser = _make_cmoriser_for_formula(monthly_ds)

    # NOTE(review): the second target resolves through ``atmosphere.np``; if
    # that is the shared numpy module, ``array_equal`` is replaced for every
    # caller while the context is active -- confirm this is intended.
    with patch(
        "access_moppy.atmosphere.evaluate_expression",
        return_value=same_size_result,
    ):
        with patch(
            "access_moppy.atmosphere.np.array_equal",
            side_effect=RuntimeError("boom"),
        ):
            cmoriser.select_and_process_variables()

    actual = cmoriser.ds["tasmax"].values
    expected = same_size_result.values
    np.testing.assert_allclose(actual, expected)


class TestSoilDepthDimension:
"""
Expand Down
52 changes: 52 additions & 0 deletions tests/unit/test_derivations_calc_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import pytest
import xarray as xr

from access_moppy.derivations import calc_utils as calc_utils_mod
from access_moppy.derivations.calc_utils import (
add_axis,
calculate_monthly_maximum,
Expand Down Expand Up @@ -358,6 +359,17 @@ def test_values_are_monthly_maxima(self):
# January maximum should be 99.0
assert float(result.values[0]) == pytest.approx(99.0)

@pytest.mark.unit
def test_ignores_fill_value_marker_in_maximum(self):
    """A ``_FillValue`` sentinel stored in the data must not win the maximum."""
    sentinel = 1e20
    daily_values = np.linspace(10.0, 20.0, 30, dtype=np.float32)
    # Typical CF sentinel that otherwise dominates monthly maximum.
    daily_values[5] = np.float32(sentinel)

    series = xr.DataArray(
        daily_values,
        dims=["time"],
        coords={"time": xr.date_range("2000-01-01", periods=30, freq="D")},
        attrs={"_FillValue": sentinel},
    )

    monthly = calculate_monthly_maximum(series)
    first_month_max = float(monthly.values[0])
    assert first_month_max == pytest.approx(20.0, rel=1e-6)

@pytest.mark.unit
def test_raises_for_missing_time_dim(self):
da = xr.DataArray(np.ones(4), dims=["lat"])
Expand Down Expand Up @@ -420,6 +432,46 @@ def test_resample_failure_raises_runtime_error(self):
calculate_monthly_maximum(da)


class TestMaskMissingValuesForReduction:
    """Unit tests for ``calc_utils._mask_missing_values_for_reduction``."""

    @staticmethod
    def _run(values, attrs=None, encoding=None):
        """Build a 1-D ``time`` array with the given markers; return (input, output)."""
        source = xr.DataArray(np.array(values), dims=["time"])
        for key, marker in (encoding or {}).items():
            source.encoding[key] = marker
        for key, marker in (attrs or {}).items():
            source.attrs[key] = marker
        return source, calc_utils_mod._mask_missing_values_for_reduction(source)

    @pytest.mark.unit
    def test_no_markers_returns_input_unchanged(self):
        source, result = self._run([1.0, 2.0, 3.0])
        np.testing.assert_array_equal(result.values, source.values)

    @pytest.mark.unit
    def test_masks_markers_from_encoding_and_iterable_fill_values(self):
        # Exercise iterable marker path and non-finite marker comparison path.
        _, result = self._run(
            [1.0, 99.0, np.inf],
            attrs={"_FillValue": np.array([np.inf])},
            encoding={"missing_value": 99.0},
        )

        kept, finite_hit, infinite_hit = result.values
        assert np.isnan(finite_hit)
        assert np.isnan(infinite_hit)
        assert float(kept) == pytest.approx(1.0)

    @pytest.mark.unit
    def test_nan_marker_masks_existing_nans(self):
        _, result = self._run([1.0, np.nan, 3.0], attrs={"missing_value": np.nan})

        assert np.isnan(result.values[1])
        assert float(result.values[0]) == pytest.approx(1.0)

    @pytest.mark.unit
    def test_ignores_non_numeric_marker_values(self):
        source, result = self._run(
            [4.0, 5.0, 6.0], attrs={"_FillValue": "not-a-number"}
        )

        np.testing.assert_array_equal(result.values, source.values)


# ---------------------------------------------------------------------------
# Monthly time-coordinate midpoint (CF/CMIP6 "time squareness", TIME001)
# ---------------------------------------------------------------------------
Expand Down