diff --git a/.gitignore b/.gitignore index 96ae21f0..df327d60 100644 --- a/.gitignore +++ b/.gitignore @@ -153,3 +153,6 @@ requirements.txt src/app/requirements.txt src/requirements.txt src/requirements-dev.txt + +# Local working notes (drafts of external messages, scratch files) +tmp/ diff --git a/src/app/api/api_v1/endpoints/alerts.py b/src/app/api/api_v1/endpoints/alerts.py index fc156c8b..3e6e17ee 100644 --- a/src/app/api/api_v1/endpoints/alerts.py +++ b/src/app/api/api_v1/endpoints/alerts.py @@ -20,7 +20,7 @@ from app.core.time import utcnow from app.crud import AlertCRUD from app.db import get_session -from app.models import Alert, AlertSequence, Camera, Sequence, UserRole +from app.models import Alert, AlertSequence, AnnotationType, Camera, Sequence, UserRole from app.schemas.alerts import AlertReadWithSequences from app.schemas.login import TokenPayload from app.schemas.sequences import SequenceRead @@ -88,33 +88,107 @@ def _serialize_alert( ) -_ALERT_EXPORT_COLUMNS = ["id", "lat", "lon", "started_at", "last_seen_at"] - - -def _iter_alerts_csv(alerts: Iterable[Alert]) -> Iterator[str]: +_ALERT_EXPORT_COLUMNS = [ + "alert_id", + "alert_started_at_date", + "alert_started_at_time", + "alert_last_seen_at", + "alert_duration_seconds", + "alert_triangulated_lat", + "alert_triangulated_lon", + "organization_id", + "sequence_id", + "sequence_started_at", + "sequence_last_seen_at", + "sequence_triangulated_azimuth", + "sequence_label", + "pose_id", + "camera_id", + "camera_name", +] + +_WILDFIRE_LABELS: Dict[Union[AnnotationType, None], str] = { + AnnotationType.WILDFIRE_SMOKE: "wildfire", + AnnotationType.OTHER_SMOKE: "other", + AnnotationType.OTHER: "other", + None: "unknown", +} + + +async def _fetch_camera_names_by_ids(session: AsyncSession, camera_ids: Iterable[int]) -> Dict[int, str]: + ids = list(set(camera_ids)) + if not ids: + return {} + stmt: Any = select(Camera.id, Camera.name).where(cast(Any, Camera.id).in_(ids)) + return {cid: name for cid, name in (await session.exec(stmt)).all()} + + +def _alert_cells(alert: Alert) -> List[Any]: + return [ + alert.id, + alert.started_at.date().isoformat(), + alert.started_at.time().isoformat(), + alert.last_seen_at.isoformat(), + int((alert.last_seen_at - alert.started_at).total_seconds()), + "" if alert.lat is None else alert.lat, + "" if alert.lon is None else alert.lon, + alert.organization_id, + ] + + +def _sequence_cells(sequence: Sequence, camera_name: str) -> List[Any]: + return [ + sequence.id, + sequence.started_at.isoformat(), + sequence.last_seen_at.isoformat(), + "" if sequence.sequence_azimuth is None else sequence.sequence_azimuth, + _WILDFIRE_LABELS[sequence.is_wildfire], + "" if sequence.pose_id is None else sequence.pose_id, + sequence.camera_id, + camera_name, + ] + + +def _iter_alerts_csv( + alerts: Iterable[Alert], + seq_map: Dict[int, List[Sequence]], + camera_names_by_id: Dict[int, str], +) -> Iterator[str]: buf = io.StringIO() writer = csv.writer(buf) - writer.writerow(_ALERT_EXPORT_COLUMNS) - yield buf.getvalue() - buf.seek(0) - buf.truncate(0) - for a in alerts: - writer.writerow([ - a.id, - "" if a.lat is None else a.lat, - "" if a.lon is None else a.lon, - a.started_at.isoformat(), - a.last_seen_at.isoformat(), - ]) - yield buf.getvalue() + + def drain() -> str: + value = buf.getvalue() buf.seek(0) buf.truncate(0) + return value - -def _build_alerts_csv_response(alerts: List[Alert], from_date: date, to_date: date) -> StreamingResponse: + writer.writerow(_ALERT_EXPORT_COLUMNS) + yield drain() + + for alert in alerts: + alert_cells = _alert_cells(alert) + sequences = sorted(seq_map.get(alert.id, []), key=lambda s: s.started_at) + for sequence in sequences: + camera_name = camera_names_by_id.get(sequence.camera_id, "") + writer.writerow([*alert_cells, *_sequence_cells(sequence, camera_name)]) + yield drain() + + +def _build_alerts_csv_response( + alerts: List[Alert], + seq_map: Dict[int, List[Sequence]], + camera_names_by_id: Dict[int, str], + from_date: date, + to_date: date, +) -> StreamingResponse: filename = f"alerts_{from_date.isoformat()}_{to_date.isoformat()}.csv" headers = {"Content-Disposition": f'attachment; filename="{filename}"'} - return StreamingResponse(_iter_alerts_csv(alerts), media_type="text/csv", headers=headers) + return StreamingResponse( + _iter_alerts_csv(alerts, seq_map, camera_names_by_id), + media_type="text/csv", + headers=headers, + ) @router.get( @@ -152,7 +226,13 @@ async def export_alerts_csv( .where(Alert.started_at <= end_dt) .order_by(Alert.started_at.asc()) # type: ignore[attr-defined] ) - return _build_alerts_csv_response(list((await session.exec(stmt)).all()), from_date, to_date) + alerts = list((await session.exec(stmt)).all()) + seq_map = await _fetch_sequences_by_alert_ids(session, [alert.id for alert in alerts]) + camera_names_by_id = await _fetch_camera_names_by_ids( + session, + (sequence.camera_id for sequences in seq_map.values() for sequence in sequences), + ) + return _build_alerts_csv_response(alerts, seq_map, camera_names_by_id, from_date, to_date) @router.get("/{alert_id}", status_code=status.HTTP_200_OK, summary="Fetch the information of a specific alert") diff --git a/src/tests/endpoints/test_alerts.py b/src/tests/endpoints/test_alerts.py index feb8005c..ad1e55cc 100644 --- a/src/tests/endpoints/test_alerts.py +++ b/src/tests/endpoints/test_alerts.py @@ -6,7 +6,7 @@ import csv import io from datetime import datetime, timedelta -from typing import Any, List, Tuple, cast +from typing import Any, Dict, List, Tuple, cast import pandas as pd import pytest # type: ignore @@ -14,6 +14,7 @@ from sqlmodel import select from sqlmodel.ext.asyncio.session import AsyncSession +from app.api.api_v1.endpoints.alerts import _ALERT_EXPORT_COLUMNS, _iter_alerts_csv from app.core.config import settings from app.core.time import utcnow from app.models import Alert, AlertSequence, AnnotationType, Camera, Detection, Organization, Pose, Sequence @@ -399,136 +400,280 @@ async def _create_alert( return alert -def _parse_csv_body(body: str) -> Tuple[List[str], List[List[str]]]: - reader = csv.reader(io.StringIO(body)) +async def _attach_sequence( + session: AsyncSession, + alert: Alert, + *, + camera_id: int = 1, + is_wildfire: AnnotationType | None = None, + sequence_azimuth: float | None = 100.0, + pose_id: int | None = None, + started_at: datetime | None = None, + last_seen_at: datetime | None = None, +) -> Sequence: + seq = Sequence( + camera_id=camera_id, + pose_id=pose_id, + camera_azimuth=100.0, + is_wildfire=is_wildfire, + sequence_azimuth=sequence_azimuth, + cone_angle=1.0, + started_at=started_at or alert.started_at, + last_seen_at=last_seen_at or alert.last_seen_at, + ) + session.add(seq) + await session.commit() + await session.refresh(seq) + session.add(AlertSequence(alert_id=alert.id, sequence_id=seq.id)) + await session.commit() + return seq + + +def _parse_export_csv(body: str) -> Tuple[List[str], List[Dict[str, str]]]: + reader = csv.DictReader(io.StringIO(body)) rows = list(reader) - return rows[0], rows[1:] + return list(reader.fieldnames or []), rows -@pytest.mark.asyncio -async def test_alerts_export_happy_path(async_client: AsyncClient, detection_session: AsyncSession): - base = datetime(2026, 4, 10, 12, 0, 0) - alerts = [ - await _create_alert(detection_session, 1, base, base + timedelta(minutes=5), 48.1, 2.1), - await _create_alert( - detection_session, 1, base + timedelta(days=1), base + timedelta(days=1, minutes=5), 48.2, 2.2 - ), - await _create_alert( - detection_session, 1, base + timedelta(days=2), base + timedelta(days=2, minutes=5), 48.3, 2.3 - ), - ] +# ───────────────────────────────────────────────────────────────────────────── +# Unit tests for _iter_alerts_csv: pure serializer behavior, no DB / HTTP / auth +# ───────────────────────────────────────────────────────────────────────────── - auth = pytest.get_token( - pytest.user_table[0]["id"], pytest.user_table[0]["role"].split(), pytest.user_table[0]["organization_id"] + +_UNIT_BASE_DT = datetime(2026, 4, 10, 12, 0, 0) + + +def _make_alert( + *, + id_: int = 1, + organization_id: int = 1, + lat: float | None = 48.0, + lon: float | None = 2.0, + started_at: datetime | None = None, + last_seen_at: datetime | None = None, +) -> Alert: + return Alert( + id=id_, + organization_id=organization_id, + lat=lat, + lon=lon, + started_at=started_at or _UNIT_BASE_DT, + last_seen_at=last_seen_at or _UNIT_BASE_DT + timedelta(minutes=5), ) - resp = await async_client.get( - "/alerts/export?from_date=2026-04-10&to_date=2026-04-12", - headers=auth, + + +def _make_sequence( + *, + id_: int = 1, + camera_id: int = 1, + pose_id: int | None = None, + is_wildfire: AnnotationType | None = None, + sequence_azimuth: float | None = 100.0, + started_at: datetime | None = None, + last_seen_at: datetime | None = None, +) -> Sequence: + return Sequence( + id=id_, + camera_id=camera_id, + pose_id=pose_id, + camera_azimuth=100.0, + is_wildfire=is_wildfire, + sequence_azimuth=sequence_azimuth, + cone_angle=1.0, + started_at=started_at or _UNIT_BASE_DT, + last_seen_at=last_seen_at or _UNIT_BASE_DT + timedelta(minutes=5), ) - assert resp.status_code == 200, resp.text - assert resp.headers["content-type"].startswith("text/csv") - assert "attachment" in resp.headers["content-disposition"] - assert "alerts_2026-04-10_2026-04-12.csv" in resp.headers["content-disposition"] - header, data_rows = _parse_csv_body(resp.text) - assert header == ["id", "lat", "lon", "started_at", "last_seen_at"] - assert [int(r[0]) for r in data_rows] == [a.id for a in alerts] - # ordering is ascending by started_at - started_values = [r[3] for r in data_rows] - assert started_values == sorted(started_values) - # spot-check values for the first row - assert float(data_rows[0][1]) == pytest.approx(48.1) - assert float(data_rows[0][2]) == pytest.approx(2.1) - assert data_rows[0][3] == alerts[0].started_at.isoformat() - assert data_rows[0][4] == alerts[0].last_seen_at.isoformat() +def _run_iter( + alerts: List[Alert], + seq_map: Dict[int, List[Sequence]], + camera_names_by_id: Dict[int, str], +) -> Tuple[List[str], List[Dict[str, str]]]: + body = "".join(_iter_alerts_csv(alerts, seq_map, camera_names_by_id)) + return _parse_export_csv(body) -@pytest.mark.asyncio -async def test_alerts_export_window_narrows(async_client: AsyncClient, detection_session: AsyncSession): - base = datetime(2026, 4, 10, 12, 0, 0) - await _create_alert(detection_session, 1, base, base + timedelta(minutes=5)) - a_in = await _create_alert(detection_session, 1, base + timedelta(days=1), base + timedelta(days=1, minutes=5)) - await _create_alert(detection_session, 1, base + timedelta(days=2), base + timedelta(days=2, minutes=5)) - auth = pytest.get_token( - pytest.user_table[0]["id"], pytest.user_table[0]["role"].split(), pytest.user_table[0]["organization_id"] - ) - resp = await async_client.get( - "/alerts/export?from_date=2026-04-11&to_date=2026-04-11", - headers=auth, - ) - assert resp.status_code == 200, resp.text - _, data_rows = _parse_csv_body(resp.text) - returned_ids = {int(r[0]) for r in data_rows} - assert returned_ids == {a_in.id} +def test_iter_alerts_csv_emits_only_header_when_no_alerts(): + header, rows = _run_iter([], {}, {}) + assert header == _ALERT_EXPORT_COLUMNS + assert rows == [] -@pytest.mark.asyncio -async def test_alerts_export_org_isolation(async_client: AsyncClient, detection_session: AsyncSession): - base = datetime(2026, 4, 10, 12, 0, 0) - org1_alert = await _create_alert(detection_session, 1, base, base + timedelta(minutes=5)) - org2_alert = await _create_alert(detection_session, 2, base, base + timedelta(minutes=5)) +def test_iter_alerts_csv_renders_null_coordinates_as_empty(): + alert = _make_alert(lat=None, lon=None) + sequence = _make_sequence() + _, rows = _run_iter([alert], {alert.id: [sequence]}, {sequence.camera_id: "cam-1"}) + assert rows[0]["alert_triangulated_lat"] == "" + assert rows[0]["alert_triangulated_lon"] == "" - # Call as a non-admin user from org 1 - auth = pytest.get_token( - pytest.user_table[1]["id"], pytest.user_table[1]["role"].split(), pytest.user_table[1]["organization_id"] - ) - resp = await async_client.get( - "/alerts/export?from_date=2026-04-10&to_date=2026-04-10", - headers=auth, - ) + +def test_iter_alerts_csv_emits_one_row_per_sequence_sorted_by_started_at(): + alert = _make_alert(id_=10, started_at=_UNIT_BASE_DT, last_seen_at=_UNIT_BASE_DT + timedelta(minutes=30)) + # Provided in non-monotonic order to verify the serializer sorts ASC by sequence.started_at. + sequences = [ + _make_sequence( + id_=20, + started_at=_UNIT_BASE_DT + timedelta(minutes=10), + last_seen_at=_UNIT_BASE_DT + timedelta(minutes=20), + ), + _make_sequence( + id_=30, + started_at=_UNIT_BASE_DT + timedelta(minutes=20), + last_seen_at=_UNIT_BASE_DT + timedelta(minutes=30), + ), + _make_sequence(id_=10, started_at=_UNIT_BASE_DT, last_seen_at=_UNIT_BASE_DT + timedelta(minutes=10)), + ] + _, rows = _run_iter([alert], {alert.id: sequences}, {1: "cam-1"}) + assert [int(r["sequence_id"]) for r in rows] == [10, 20, 30] + # Alert-level cells repeat across rows + assert {r["alert_started_at_date"] for r in rows} == {alert.started_at.date().isoformat()} + assert {r["alert_last_seen_at"] for r in rows} == {alert.last_seen_at.isoformat()} + + +@pytest.mark.parametrize( + ("is_wildfire", "expected_label"), + [ + (AnnotationType.WILDFIRE_SMOKE, "wildfire"), + (AnnotationType.OTHER_SMOKE, "other"), + (AnnotationType.OTHER, "other"), + (None, "unknown"), + ], +) +def test_iter_alerts_csv_wildfire_label_mapping(is_wildfire: AnnotationType | None, expected_label: str): + alert = _make_alert() + sequence = _make_sequence(is_wildfire=is_wildfire) + _, rows = _run_iter([alert], {alert.id: [sequence]}, {sequence.camera_id: "cam-1"}) + assert rows[0]["sequence_label"] == expected_label + + +def test_iter_alerts_csv_resolves_camera_name_per_sequence(): + alert = _make_alert() + seq_a = _make_sequence(id_=1, camera_id=1, started_at=_UNIT_BASE_DT) + seq_b = _make_sequence(id_=2, camera_id=99, started_at=_UNIT_BASE_DT + timedelta(minutes=10)) + _, rows = _run_iter([alert], {alert.id: [seq_a, seq_b]}, {1: "cam-a", 99: "cam-b"}) + cam_by_seq = {int(r["sequence_id"]): r["camera_name"] for r in rows} + assert cam_by_seq == {1: "cam-a", 2: "cam-b"} + + +# ───────────────────────────────────────────────────────────────────────────── +# Integration tests for GET /alerts/export: route wiring, SQL filter, JWT scope +# ───────────────────────────────────────────────────────────────────────────── + + +async def _get_export( + async_client: AsyncClient, auth: Dict[str, str], from_date: str, to_date: str +) -> Tuple[List[str], List[Dict[str, str]]]: + resp = await async_client.get(f"/alerts/export?from_date={from_date}&to_date={to_date}", headers=auth) assert resp.status_code == 200, resp.text - _, data_rows = _parse_csv_body(resp.text) - returned_ids = {int(r[0]) for r in data_rows} - assert org1_alert.id in returned_ids - assert org2_alert.id not in returned_ids + return _parse_export_csv(resp.text) + + +@pytest.fixture +def export_base_dt() -> datetime: + """Anchor datetime for export integration tests; date is stable so query windows stay readable.""" + return datetime(2026, 4, 10, 12, 0, 0) + + +@pytest.fixture +def org1_admin_auth() -> Dict[str, str]: + user = pytest.user_table[0] + return pytest.get_token(user["id"], user["role"].split(), user["organization_id"]) + + +@pytest.fixture +def org1_agent_auth() -> Dict[str, str]: + user = pytest.user_table[1] + return pytest.get_token(user["id"], user["role"].split(), user["organization_id"]) @pytest.mark.asyncio -async def test_alerts_export_empty_range(async_client: AsyncClient, detection_session: AsyncSession): - auth = pytest.get_token( - pytest.user_table[0]["id"], pytest.user_table[0]["role"].split(), pytest.user_table[0]["organization_id"] - ) - resp = await async_client.get( - "/alerts/export?from_date=2099-01-01&to_date=2099-01-31", - headers=auth, - ) +async def test_alerts_export_happy_path( + async_client: AsyncClient, + detection_session: AsyncSession, + export_base_dt: datetime, + org1_admin_auth: Dict[str, str], +): + alerts: List[Alert] = [] + for offset_days, (lat, lon) in enumerate([(48.1, 2.1), (48.2, 2.2), (48.3, 2.3)]): + started = export_base_dt + timedelta(days=offset_days) + alert = await _create_alert(detection_session, 1, started, started + timedelta(minutes=5), lat, lon) + await _attach_sequence(detection_session, alert) + alerts.append(alert) + + resp = await async_client.get("/alerts/export?from_date=2026-04-10&to_date=2026-04-12", headers=org1_admin_auth) assert resp.status_code == 200, resp.text - header, data_rows = _parse_csv_body(resp.text) - assert header == ["id", "lat", "lon", "started_at", "last_seen_at"] - assert data_rows == [] + assert resp.headers["content-type"].startswith("text/csv") + assert "attachment" in resp.headers["content-disposition"] + assert "alerts_2026-04-10_2026-04-12.csv" in resp.headers["content-disposition"] + + _, rows = _parse_export_csv(resp.text) + assert [int(r["alert_id"]) for r in rows] == [a.id for a in alerts] + # ordering is ascending by alert.started_at + started_iso = [f"{r['alert_started_at_date']}T{r['alert_started_at_time']}" for r in rows] + assert started_iso == sorted(started_iso) + # One dict equality covers column set, names, and values in a single pytest diff. + first = rows[0] + assert first == { + "alert_id": str(alerts[0].id), + "alert_started_at_date": alerts[0].started_at.date().isoformat(), + "alert_started_at_time": alerts[0].started_at.time().isoformat(), + "alert_last_seen_at": alerts[0].last_seen_at.isoformat(), + "alert_duration_seconds": str(int((alerts[0].last_seen_at - alerts[0].started_at).total_seconds())), + "alert_triangulated_lat": "48.1", + "alert_triangulated_lon": "2.1", + "organization_id": "1", + "sequence_id": str(first["sequence_id"]), # id auto-generated, just round-trip + "sequence_started_at": alerts[0].started_at.isoformat(), + "sequence_last_seen_at": alerts[0].last_seen_at.isoformat(), + "sequence_triangulated_azimuth": "100.0", + "sequence_label": "unknown", + "pose_id": "", + "camera_id": "1", + "camera_name": "cam-1", + } @pytest.mark.asyncio -async def test_alerts_export_renders_null_coordinates_as_empty( - async_client: AsyncClient, detection_session: AsyncSession +async def test_alerts_export_window_narrows( + async_client: AsyncClient, + detection_session: AsyncSession, + export_base_dt: datetime, + org1_admin_auth: Dict[str, str], ): - base = datetime(2026, 4, 10, 12, 0, 0) - alert = await _create_alert(detection_session, 1, base, base + timedelta(minutes=5), lat=None, lon=None) + for offset_days in range(3): + started = export_base_dt + timedelta(days=offset_days) + alert = await _create_alert(detection_session, 1, started, started + timedelta(minutes=5)) + await _attach_sequence(detection_session, alert) - auth = pytest.get_token( - pytest.user_table[0]["id"], pytest.user_table[0]["role"].split(), pytest.user_table[0]["organization_id"] - ) - resp = await async_client.get( - "/alerts/export?from_date=2026-04-10&to_date=2026-04-10", - headers=auth, - ) - assert resp.status_code == 200, resp.text - _, data_rows = _parse_csv_body(resp.text) - row = next(r for r in data_rows if int(r[0]) == alert.id) - assert row[1] == "" - assert row[2] == "" + _, rows = await _get_export(async_client, org1_admin_auth, "2026-04-11", "2026-04-11") + returned_dates = {r["alert_started_at_date"] for r in rows} + assert returned_dates == {"2026-04-11"} @pytest.mark.asyncio -async def test_alerts_export_invalid_range(async_client: AsyncClient, detection_session: AsyncSession): - auth = pytest.get_token( - pytest.user_table[0]["id"], pytest.user_table[0]["role"].split(), pytest.user_table[0]["organization_id"] - ) - resp = await async_client.get( - "/alerts/export?from_date=2026-04-12&to_date=2026-04-10", - headers=auth, - ) +async def test_alerts_export_org_isolation( + async_client: AsyncClient, + detection_session: AsyncSession, + export_base_dt: datetime, + org1_agent_auth: Dict[str, str], +): + org1_alert = await _create_alert(detection_session, 1, export_base_dt, export_base_dt + timedelta(minutes=5)) + await _attach_sequence(detection_session, org1_alert, camera_id=1) + org2_alert = await _create_alert(detection_session, 2, export_base_dt, export_base_dt + timedelta(minutes=5)) + await _attach_sequence(detection_session, org2_alert, camera_id=2) + + _, rows = await _get_export(async_client, org1_agent_auth, "2026-04-10", "2026-04-10") + returned_ids = {int(r["alert_id"]) for r in rows} + assert org1_alert.id in returned_ids + assert org2_alert.id not in returned_ids + + +@pytest.mark.asyncio +async def test_alerts_export_invalid_range( + async_client: AsyncClient, detection_session: AsyncSession, org1_admin_auth: Dict[str, str] +): + resp = await async_client.get("/alerts/export?from_date=2026-04-12&to_date=2026-04-10", headers=org1_admin_auth) assert resp.status_code == 422, resp.text