Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions src/oceanum/datamesh/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def __init__(
self,
token=None,
service=os.environ.get("DATAMESH_SERVICE", DEFAULT_CONFIG["DATAMESH_SERVICE"]),
gateway=os.environ.get("DATAMESH_GATEWAY", None),
_gateway=os.environ.get("DATAMESH_GATEWAY", None),
user=None,
session_duration=None,
verify=True,
Expand All @@ -91,6 +91,8 @@ def __init__(

Args:
token (string): Your datamesh access token. Defaults to os.environ.get("DATAMESH_TOKEN", None).
service (string): The datamesh service url. Defaults to os.environ.get("DATAMESH_SERVICE", "https://datamesh.oceanum.io").
user (string, optional): Optional user identifier to be sent in the header for datamesh authentication. Defaults to None.
session_duration (float, optional): The desired length of time for acquired datamesh sessions in seconds. Will be 3600 seconds by default.
verify (bool, optional): Whether to verify the datamesh server certificate. Defaults to True.
Raises:
Expand All @@ -108,7 +110,12 @@ def __init__(
self._session_params = (
{"duration": float(session_duration)} if session_duration else {}
)
self._gateway = gateway
if _gateway and re.match(r"^https?://gateway\.datamesh(-v0)?\.oceanum\.(io|tech)", _gateway):
warnings.warn(
f"The gateway url {_gateway} is deprecated. Please use https://datamesh.oceanum.io or https://datamesh.oceanum.tech instead.",
DeprecationWarning,
)
self._gateway = _gateway or f"{self._proto}://{self._host}"
self._cachedir = tempfile.TemporaryDirectory(prefix="datamesh_")
self._verify = verify

Expand Down Expand Up @@ -170,7 +177,6 @@ def _check_info(self):
Typically will ask to update the client if the version is outdated.
Also will set gateway address to service address if not provided.
"""
self._gateway = self._gateway or f"{self._proto}://{self._host}"
try:
resp = self._retried_request(
f"{self._gateway}/info/oceanum_python/{__version__}",
Expand All @@ -185,7 +191,7 @@ def _check_info(self):
f"Failed to reach datamesh: {resp.status_code}-{resp.text}"
)
except Exception as e:
warnings.warn(f"Failed to reach datamesh gateway at {self._gateway}: {e}")
warnings.warn(f"Failed to check status of datamesh gateway at {self._gateway}: {e}")

def _validate_response(self, resp):
if resp.status_code >= 400:
Expand Down
22 changes: 21 additions & 1 deletion src/oceanum/datamesh/zarr.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,11 @@ def zarr_write(
overwrite=False,
group: Optional[str] = None,
):
def _is_monotonic_non_decreasing(values) -> bool:
if len(values) < 2:
return True
return bool(numpy.all(values[:-1] <= values[1:]))

with Session.acquire(connection) as session:
store = ZarrClient(connection, datasource_id, session, api="zarr", nocache=True)
if overwrite is True:
Expand All @@ -229,11 +234,20 @@ def zarr_write(
raise DatameshWriteError(
f"Append coordinate {append} has more than one dimension"
)
cnew = data[append]
if not _is_monotonic_non_decreasing(cnew.values):
raise DatameshWriteError(
f"Append coordinate {append} in incoming data must be monotonic non-decreasing"
)
append_dim = cexist.dims[0]
(replace_range,) = numpy.nonzero(
((cexist >= data[append][0]) & (cexist <= data[append][-1])).values
) # Get range in new data which overlaps - this just replaces everything >= first value in the new data
if len(replace_range):
if not numpy.all(numpy.diff(replace_range) == 1):
raise DatameshWriteError(
f"Cannot append on coordinate {append}: overlapping indices in existing zarr are non-contiguous (existing coordinate likely non-monotonic)"
)
# Fail if the replacement range is larger than incoming data
if len(replace_range) > len(data[append]):
raise DatameshWriteError(
Expand All @@ -246,8 +260,14 @@ def zarr_write(
]
replace_section = data.isel(
**{append_dim: slice(0, len(replace_range))}
).drop(drop_coords + drop_vars)
).drop_vars(drop_coords + drop_vars, errors="ignore")
replace_slice = slice(replace_range[0], replace_range[-1] + 1)
replace_coord = replace_section[append]
existing_coord = cexist[replace_slice]
if not numpy.array_equal(replace_coord.values, existing_coord.values):
raise DatameshWriteError(
f"Cannot append on coordinate {append}: overlap timestamps do not match existing archive values. Inserting new timestamps into an existing coordinate range is not supported"
)
# Fail if we are replacing an internal section and ends of coordinates do not match
if replace_range[-1] + 1 < len(cexist) and not numpy.array_equal(
replace_section[append], cexist[replace_slice]
Expand Down
86 changes: 86 additions & 0 deletions tests/test_zarr_append.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
"""Unit tests for zarr_write append validation."""
import pytest
from unittest.mock import Mock, patch, MagicMock
import numpy as np
import xarray as xr

from oceanum.datamesh.zarr import zarr_write
from oceanum.datamesh.exceptions import DatameshWriteError


def _append_conn(ds_exists=True):
conn = Mock(_gateway="http://test", _auth_headers={}, _is_v1=True)
ds = Mock(_exists=ds_exists)
ds.dataschema = Mock(coords={"time": {}})
conn.get_datasource = Mock(return_value=ds)
return conn


def _session_mock():
session_mock = MagicMock()
session_mock.__enter__ = Mock(return_value=session_mock)
session_mock.__exit__ = Mock(return_value=False)
session_mock.add_header = lambda h: h
return session_mock


def test_zarr_write_append_rejects_non_monotonic_incoming_coordinate():
    """Appending data whose append coordinate goes backwards must raise."""
    conn = _append_conn(ds_exists=True)
    session_mock = _session_mock()

    existing = xr.Dataset(
        data_vars={"incli": ("time", np.arange(5))},
        coords={"time": np.array([1, 2, 3, 4, 5])},
    )
    # The step 5 -> 4 breaks the non-decreasing order of the incoming coordinate.
    incoming = xr.Dataset(
        data_vars={"incli": ("time", np.arange(4))},
        coords={"time": np.array([3, 5, 4, 6])},
    )

    acquire_patch = patch(
        "oceanum.datamesh.zarr.Session.acquire", return_value=session_mock
    )
    client_patch = patch("oceanum.datamesh.zarr.ZarrClient")
    open_patch = patch("oceanum.datamesh.zarr.xarray.open_zarr", return_value=existing)

    with acquire_patch, client_patch, open_patch:
        with pytest.raises(DatameshWriteError, match="must be monotonic non-decreasing"):
            zarr_write(conn, "test-ds", incoming, append="time")


def test_zarr_write_append_rejects_non_contiguous_overlap_indices():
    """An overlap that maps to non-adjacent existing indices must raise."""
    conn = _append_conn(ds_exists=True)
    session_mock = _session_mock()

    # The stray value 10 makes the existing coordinate non-monotonic, so the
    # existing entries falling inside the incoming range are not adjacent.
    existing = xr.Dataset(
        data_vars={"incli": ("time", np.arange(6))},
        coords={"time": np.array([1, 2, 3, 10, 4, 5])},
    )
    incoming = xr.Dataset(
        data_vars={"incli": ("time", np.arange(4))},
        coords={"time": np.array([3, 4, 5, 6])},
    )

    acquire_patch = patch(
        "oceanum.datamesh.zarr.Session.acquire", return_value=session_mock
    )
    client_patch = patch("oceanum.datamesh.zarr.ZarrClient")
    open_patch = patch("oceanum.datamesh.zarr.xarray.open_zarr", return_value=existing)

    with acquire_patch, client_patch, open_patch:
        with pytest.raises(DatameshWriteError, match="non-contiguous"):
            zarr_write(conn, "test-ds", incoming, append="time")


def test_zarr_write_append_rejects_overlap_timestamp_mismatch():
    """Overlapping incoming timestamps that differ from the archive must raise."""
    conn = _append_conn(ds_exists=True)
    session_mock = _session_mock()

    existing = xr.Dataset(
        data_vars={"incli": ("time", np.arange(6))},
        coords={"time": np.array([1, 2, 3, 4, 5, 6])},
    )
    # The incoming range 3.5 to 9.5 covers existing values 4, 5 and 6, but the
    # leading incoming timestamps (3.5, 4.5, 5.5) are not equal to them.
    incoming = xr.Dataset(
        data_vars={"incli": ("time", np.arange(7))},
        coords={"time": np.array([3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5])},
    )

    acquire_patch = patch(
        "oceanum.datamesh.zarr.Session.acquire", return_value=session_mock
    )
    client_patch = patch("oceanum.datamesh.zarr.ZarrClient")
    open_patch = patch("oceanum.datamesh.zarr.xarray.open_zarr", return_value=existing)

    with acquire_patch, client_patch, open_patch:
        with pytest.raises(DatameshWriteError, match="overlap timestamps do not match"):
            zarr_write(conn, "test-ds", incoming, append="time")
Loading