diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml new file mode 100644 index 0000000..dfbefc5 --- /dev/null +++ b/.github/workflows/test.yml @@ -0,0 +1,92 @@ +name: CI + +on: + push: + branches: [main] + pull_request: + branches: [main] + +jobs: + ruff: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v6 + + - name: Set up Python + uses: actions/setup-python@v6 + with: + python-version: "3.14" + + - name: Install ruff + run: pip install ruff + + - name: Ruff check + run: | + set +e + output=$(ruff check . --statistics 2>&1) + exit_code=$? + echo "## Ruff Lint Report" >> $GITHUB_STEP_SUMMARY + echo "\`\`\`" >> $GITHUB_STEP_SUMMARY + echo "$output" >> $GITHUB_STEP_SUMMARY + echo "\`\`\`" >> $GITHUB_STEP_SUMMARY + exit $exit_code + + - name: Ruff format check + run: | + set +e + output=$(ruff format --check . 2>&1) + exit_code=$? + echo "## Ruff Format Report" >> $GITHUB_STEP_SUMMARY + echo "\`\`\`" >> $GITHUB_STEP_SUMMARY + echo "$output" >> $GITHUB_STEP_SUMMARY + echo "\`\`\`" >> $GITHUB_STEP_SUMMARY + exit $exit_code + + mypy: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v6 + + - name: Set up Python + uses: actions/setup-python@v6 + with: + python-version: "3.14" + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -e ".[dev]" + pip install -e ".[fastapi]" + + - name: mypy + run: | + set +e + output=$(mypy open_ess 2>&1) + exit_code=$? + echo "## mypy Report" >> $GITHUB_STEP_SUMMARY + echo "\`\`\`" >> $GITHUB_STEP_SUMMARY + echo "$output" >> $GITHUB_STEP_SUMMARY + echo "\`\`\`" >> $GITHUB_STEP_SUMMARY + exit $exit_code + + pytest: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ["3.11", "3.12", "3.13", "3.14"] + + steps: + - uses: actions/checkout@v6 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v6 + with: + python-version: ${{ matrix.python-version }} + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -e ".[dev]" + + - name: Run tests + run: pytest diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 3a778b1..037e6ac 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,5 +1,14 @@ repos: - - repo: https://github.com/psf/black - rev: 26.3.1 + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v6.0.0 hooks: - - id: black + - id: check-yaml + - id: end-of-file-fixer + - id: trailing-whitespace + + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.15.12 + hooks: + - id: ruff # Linter + args: [ --fix ] + - id: ruff-format # Formatter diff --git a/default.nix b/default.nix index 4927d50..1d0c52c 100644 --- a/default.nix +++ b/default.nix @@ -25,7 +25,8 @@ buildPythonPackage { ]; meta = with lib; { - description = "Open Energy Storage System - Charge/discharge schedule optimizer for day-ahead energy prices."; + description = + "Open Energy Storage System - Charge/discharge schedule optimizer for day-ahead energy prices."; license = licenses.mit; }; } diff --git a/docs/dev/getting started.md b/docs/dev/getting started.md new file mode 100644 index 0000000..e4e6aa0 --- /dev/null +++ b/docs/dev/getting started.md @@ -0,0 +1,41 @@ +To set up a dev environment run; + +```bash +python3 -m venv .venv +source .venv/bin/activate +pip install --upgrade pip +pip install -e . +pip install -e ".[dev]" + +pre-commit install +pre-commit autoupdate +``` + +### pytest + +```bash +# Run all tests: +pytest + +# For a code coverage report: +pytest --cov=metricsqlite --cov-report=term-missing +``` + +### ruff + +```bash +ruff check . # Lint +ruff check . --fix # Lint + auto-fix +ruff format . # Format + +# pyproject.toml sets `output-format = "concise"`. To show more details run; +ruff check --output-format=full . +``` + +### mypy + +```bash +mypy open_ess + +mypy --install-types +``` diff --git a/open_ess/battery_system/__init__.py b/open_ess/battery_system/__init__.py index e27a825..48377e4 100644 --- a/open_ess/battery_system/__init__.py +++ b/open_ess/battery_system/__init__.py @@ -1,2 +1,4 @@ from .battery_system import BatterySystem, VictronBatterySystem from .config import BatterySystemConfig + +__all__ = ["BatterySystem", "BatterySystemConfig", "VictronBatterySystem"] diff --git a/open_ess/battery_system/battery_system.py b/open_ess/battery_system/battery_system.py index dce85b6..4bfc80c 100644 --- a/open_ess/battery_system/battery_system.py +++ b/open_ess/battery_system/battery_system.py @@ -1,8 +1,9 @@ import logging from abc import ABC, abstractmethod -from datetime import datetime, timedelta, timezone +from datetime import UTC, datetime, timedelta from open_ess.victron_modbus import VictronClient + from .config import BatterySystemConfig logger = logging.getLogger(__name__) @@ -17,7 +18,7 @@ def config(self) -> BatterySystemConfig: return self._config @property - def name(self) -> str: + def name(self) -> str | None: return self._config.name @property @@ -25,7 +26,7 @@ def name(self) -> str: def id(self) -> str | None: ... @abstractmethod - def set_ess_setpoint(self, power: float, until: datetime | None = None): ... + def set_ess_setpoint(self, power: float, until: datetime | None = None) -> None: ... class VictronBatterySystem(BatterySystem): @@ -40,8 +41,8 @@ def id(self) -> str | None: return None return f"victron/{self._victron_client.serial}" - def set_ess_setpoint(self, power: float, until: datetime | None = None): + def set_ess_setpoint(self, power: float, until: datetime | None = None) -> None: if until is None: - until = datetime.now(tz=timezone.utc) + timedelta(hours=1) + until = datetime.now(tz=UTC) + timedelta(hours=1) logger.info(f"{self.name}: Set setpoint to {power} W") self._victron_client.set_ess_setpoint(power, until) diff --git a/open_ess/battery_system/config.py b/open_ess/battery_system/config.py index 3528397..19060c1 100644 --- a/open_ess/battery_system/config.py +++ b/open_ess/battery_system/config.py @@ -1,6 +1,6 @@ from typing import Annotated, Literal -from pydantic import BaseModel, Field, model_validator, computed_field +from pydantic import BaseModel, Field, computed_field, model_validator from open_ess.victron_modbus import VictronConfig @@ -39,8 +39,8 @@ class BatterySystemConfig(BaseModel): control: Annotated[VictronConfig | MqttControl, Field(discriminator="type")] metrics: MetricsConfig = MetricsConfig() - @computed_field @property + @computed_field def id(self) -> str: if isinstance(self.control, VictronConfig): return f"victron/vebus/{self.control.vebus_id}" @@ -52,7 +52,7 @@ def is_victron(self) -> bool: return isinstance(self.control, VictronConfig) @model_validator(mode="after") - def check_power_limits(self): + def check_power_limits(self) -> "BatterySystemConfig": if not self.monitor_only: if self.max_charge_power_kw is None: raise ValueError( @@ -65,11 +65,11 @@ def check_power_limits(self): return self @model_validator(mode="after") - def set_defaults(self): + def set_defaults(self) -> "BatterySystemConfig": if self.name is None: self.name = self.id - if self.is_victron: + if isinstance(self.control, VictronConfig): vebus_prefix = self.control.vebus_prefix bms_prefix = self.control.battery_prefix diff --git a/open_ess/config.py b/open_ess/config.py index 9be8ce2..66bdc8f 100644 --- a/open_ess/config.py +++ b/open_ess/config.py @@ -3,9 +3,9 @@ import yaml from pydantic import BaseModel +from open_ess.battery_system import BatterySystemConfig from open_ess.database import DatabaseConfig from open_ess.frontend import FrontendConfig -from open_ess.battery_system import BatterySystemConfig from open_ess.pricing import PriceConfig # TODO: Validate config. If a battery defines mqtt control, require mqtt config. diff --git a/open_ess/database/__init__.py b/open_ess/database/__init__.py index 33f4a86..1947c7a 100644 --- a/open_ess/database/__init__.py +++ b/open_ess/database/__init__.py @@ -1,7 +1,7 @@ from .config import DatabaseConfig from .database import Database, DatabaseConnection from .service import DatabaseService -from .util import ms_to_dt, dt_to_ms +from .util import dt_to_ms, ms_to_dt __all__ = [ "Database", diff --git a/open_ess/database/database.py b/open_ess/database/database.py index b83206f..146795f 100644 --- a/open_ess/database/database.py +++ b/open_ess/database/database.py @@ -5,7 +5,7 @@ from .config import DatabaseConfig from .migration_runner import run_migrations -from .util import dt_to_ms, ms_to_dt, base_conditions +from .util import base_conditions, dt_to_ms, ms_to_dt logger = logging.getLogger(__name__) @@ -27,7 +27,7 @@ def config(self) -> DatabaseConfig: def connect(self) -> "DatabaseConnection": return DatabaseConnection(self._config.path) - def run_migrations(self): + def run_migrations(self) -> None: with self.connect() as conn: run_migrations(conn) @@ -38,24 +38,24 @@ def __init__(self, path: Path): self._conn.row_factory = sqlite3.Row # ^ Makes column access by name possible - def close(self): + def close(self) -> None: self._conn.close() - def __enter__(self): + def __enter__(self) -> "DatabaseConnection": return self - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__(self, exc_type: object, exc_val: object, exc_tb: object) -> None: self.close() - def execute(self, sql: str, parameters=None) -> sqlite3.Cursor: + def execute(self, sql: str, parameters: list | tuple | None = None) -> sqlite3.Cursor: if parameters is None: parameters = [] return self._conn.execute(sql, parameters) - def commit(self): - return self._conn.commit() + def commit(self) -> None: + self._conn.commit() - def vacuum(self): + def vacuum(self) -> None: self._conn.execute("PRAGMA incremental_vacuum") def _get_labels( @@ -70,10 +70,7 @@ def _get_labels( conditions.append(f"{timestamp_name} < ?") params.append(dt_to_ms(end)) - if conditions: - where_clause = "WHERE " + " AND ".join(conditions) - else: - where_clause = "" + where_clause = "WHERE " + " AND ".join(conditions) if conditions else "" query = f""" SELECT DISTINCT label FROM {table_name} @@ -86,7 +83,7 @@ def _get_labels( # Power # ------------------------------------------------------------------------- - def insert_power(self, label: str, timestamp: datetime, power: float): + def insert_power(self, label: str, timestamp: datetime, power: float | None) -> None: if power is None: return self._conn.execute( @@ -111,7 +108,7 @@ def get_power( if bucket_seconds is not None: bucket_ms = round(bucket_seconds * 1000) select_clause = "(start_time / ?) * ? as bucket, AVG(value) as avg_value" - params = [bucket_ms, bucket_ms] + params + params = [bucket_ms, bucket_ms, *params] group_by = "GROUP BY bucket" order_by = "bucket" else: @@ -141,7 +138,7 @@ def get_power_labels(self, start: datetime | None = None, end: datetime | None = def get_all_power( self, start: datetime, end: datetime | None = None, bucket_seconds: float | None = None - ) -> dict[str, list[tuple[datetime, int]]]: + ) -> dict[str, list[tuple[datetime, float]]]: power_series = {} for label in self.get_power_labels(start, end): power_series[label] = self.get_power(label, start, end, bucket_seconds) @@ -223,8 +220,10 @@ def insert_energy( self, label: str, timestamp: datetime, - energy: float, - ): + energy: float | None, + ) -> None: + if energy is None: + return self._conn.execute( """ INSERT INTO energy (label, timestamp, value) @@ -264,7 +263,12 @@ def get_energy( return result def get_energy_aggregated( - self, label, aggregation_seconds: float, start: datetime | None, end: datetime | None, center_buckets=False + self, + label: str, + aggregation_seconds: float, + start: datetime | None, + end: datetime | None, + center_buckets: bool = False, ) -> list[tuple[datetime, float]]: if start: start -= timedelta(seconds=aggregation_seconds) @@ -298,7 +302,7 @@ def get_energy_aggregated( GROUP BY bucket ORDER BY bucket """ - cursor = self._conn.execute(query, [agg_ms, agg_ms] + params) + cursor = self._conn.execute(query, [agg_ms, agg_ms, *params]) center_offset = agg_ms // 2 if center_buckets else 0 return [(ms_to_dt(r[0] + center_offset), round(r[1], 3)) for r in cursor.fetchall()] @@ -339,7 +343,7 @@ def integrate_power( # Day-ahead prices # ------------------------------------------------------------------------- - def insert_price(self, area: str, start_time: datetime, end_time: datetime, price: float): + def insert_price(self, area: str, start_time: datetime, end_time: datetime, price: float) -> None: self._conn.execute( """ INSERT INTO day_ahead_prices (area, start_time, end_time, price) @@ -351,7 +355,7 @@ def insert_price(self, area: str, start_time: datetime, end_time: datetime, pric ) self._conn.commit() - def insert_prices(self, area: str, prices: list[tuple[datetime, datetime, float]]): + def insert_prices(self, area: str, prices: list[tuple[datetime, datetime, float]]) -> None: self._conn.executemany( """ INSERT INTO day_ahead_prices (area, start_time, end_time, price) @@ -365,7 +369,7 @@ def insert_prices(self, area: str, prices: list[tuple[datetime, datetime, float] logger.debug(f"Inserted {len(prices)} price records") def get_prices( - self, area: str, start: datetime, end: datetime = None, aggregate_minutes: float = None + self, area: str, start: datetime, end: datetime | None = None, aggregate_minutes: float | None = None ) -> list[tuple[datetime, float]]: conditions, params = base_conditions(area, start, end, label_name="area", timestamp_name="start_time") @@ -375,7 +379,7 @@ def get_prices( select_clause = f"(start_time / ?) * ? as {timestamp_column}, AVG(price) as {value_column}" group_by = f"GROUP BY {timestamp_column}" agg_ms = round(aggregate_minutes * 60000) - params = [agg_ms, agg_ms] + params + params = [agg_ms, agg_ms, *params] else: timestamp_column = "start_time" group_by = "" @@ -428,7 +432,7 @@ def get_latest_price_time(self, area: str) -> datetime | None: # Battery SOC # ------------------------------------------------------------------------- - def insert_soc(self, label: str, timestamp: datetime, soc: int): + def insert_soc(self, label: str, timestamp: datetime, soc: int) -> None: # TODO: also insert if last update was more than 5 minutes ago self._conn.execute( """ @@ -443,7 +447,7 @@ def insert_soc(self, label: str, timestamp: datetime, soc: int): ) self._conn.commit() - def get_battery_soc(self, label: str, start: datetime, end: datetime) -> list[tuple[datetime, int]]: + def get_battery_soc(self, label: str, start: datetime, end: datetime) -> list[tuple[datetime, float]]: if isinstance(label, list): label = label[0] cursor = self._conn.execute( diff --git a/open_ess/database/migration_runner.py b/open_ess/database/migration_runner.py index 0636435..406de11 100644 --- a/open_ess/database/migration_runner.py +++ b/open_ess/database/migration_runner.py @@ -1,8 +1,8 @@ import logging +from datetime import UTC, datetime from importlib import import_module from pathlib import Path from typing import TYPE_CHECKING -from datetime import datetime, timezone if TYPE_CHECKING: from .database import DatabaseConnection @@ -31,7 +31,7 @@ def get_migrations() -> list[tuple[int, str]]: return sorted(migrations) -def run_migration(version: int, module_name: str, conn) -> None: +def run_migration(version: int, module_name: str, conn: "DatabaseConnection") -> None: """Run a single migration. Args: @@ -43,7 +43,7 @@ def run_migration(version: int, module_name: str, conn) -> None: module.upgrade(conn) -def run_migrations(conn: "DatabaseConnection"): +def run_migrations(conn: "DatabaseConnection") -> None: conn.execute(""" CREATE TABLE IF NOT EXISTS schema_version ( version INTEGER PRIMARY KEY, @@ -62,7 +62,7 @@ def run_migrations(conn: "DatabaseConnection"): run_migration(version, module_name, conn) conn.execute( "INSERT INTO schema_version (version, applied_at) VALUES (?, ?)", - (version, datetime.now(timezone.utc)), + (version, datetime.now(UTC)), ) conn.commit() logger.info(f"Migration {version} complete") diff --git a/open_ess/database/service.py b/open_ess/database/service.py index 29a755f..b88ee5b 100644 --- a/open_ess/database/service.py +++ b/open_ess/database/service.py @@ -1,7 +1,8 @@ import logging -from datetime import datetime, timezone, timedelta +from datetime import UTC, datetime, timedelta from open_ess.service import Service + from .database import Database, DatabaseConnection logger = logging.getLogger(__name__) @@ -14,23 +15,25 @@ def __init__(self, database: Database): self._config = database.config self._db_conn: DatabaseConnection | None = None - def on_start(self): + def on_start(self) -> None: self._db_conn = self._database.connect() logger.info("DatabaseService started") - def tick(self): + def tick(self) -> None: self._run_compression() - def _run_compression(self): + def _run_compression(self) -> None: + if self._db_conn is None: + return None if self._config.compression.enable: - n_samples, n_buckets = self._db_conn.compress_power( - datetime.now(timezone.utc), self._config.compression.bucket_seconds + n_samples, _n_buckets = self._db_conn.compress_power( + datetime.now(UTC), self._config.compression.bucket_seconds ) if n_samples > 0: self._db_conn.vacuum() - def wait_until_next(self): - now = datetime.now(timezone.utc) + def wait_until_next(self) -> None: + now = datetime.now(UTC) next_run = now.replace(second=0, microsecond=0) + timedelta(minutes=1, seconds=10) # ^ Run next compression 10 seconds after a new minute starts. This ensures that all new metrics # have been written to the database. diff --git a/open_ess/database/util.py b/open_ess/database/util.py index 85bbe69..ac8d692 100644 --- a/open_ess/database/util.py +++ b/open_ess/database/util.py @@ -1,4 +1,4 @@ -from datetime import datetime, timezone +from datetime import UTC, datetime def dt_to_ms(dt: datetime) -> int: @@ -8,7 +8,7 @@ def dt_to_ms(dt: datetime) -> int: def ms_to_dt(ms: int) -> datetime: """Unix milliseconds to UTC datetime.""" - return datetime.fromtimestamp(ms / 1000, tz=timezone.utc) + return datetime.fromtimestamp(ms / 1000, tz=UTC) def base_conditions( diff --git a/open_ess/frontend/__init__.py b/open_ess/frontend/__init__.py index 2f44f17..9212fc3 100644 --- a/open_ess/frontend/__init__.py +++ b/open_ess/frontend/__init__.py @@ -1,5 +1,5 @@ from .app import create_app from .config import FrontendConfig -from .dependencies import init_dependencies, close_dependencies +from .dependencies import close_dependencies, init_dependencies __all__ = ["FrontendConfig", "init_dependencies", "create_app", "close_dependencies"] diff --git a/open_ess/frontend/cli.py b/open_ess/frontend/cli.py index 0d6e7ad..49d4593 100644 --- a/open_ess/frontend/cli.py +++ b/open_ess/frontend/cli.py @@ -4,21 +4,21 @@ from open_ess.config import Config from open_ess.frontend.app import create_app -from open_ess.frontend.dependencies import init_dependencies, close_dependencies -from open_ess.util import setup_logging, parse_args +from open_ess.frontend.dependencies import close_dependencies +from open_ess.util import parse_args, setup_logging setup_logging() logger = logging.getLogger(__name__) -def main(): +def main() -> None: args = parse_args("Open Energy Storage System web dashboard") config = Config.from_file(args.config) if not config.frontend.enable: logger.info("Frontend is not enabled. Exiting...") - init_dependencies(config.database, config.prices) + # TODO: init_dependencies(config.database, config.prices, []) logger.info(f"Starting web server on http://{config.frontend.host}:{config.frontend.port}") @@ -26,7 +26,7 @@ def main(): app = create_app() uvicorn.run( app, - host=config.frontend.host, + host=config.frontend.host, # type: ignore[arg-type] port=config.frontend.port, log_level="info", ) diff --git a/open_ess/frontend/config.py b/open_ess/frontend/config.py index 7b90eaa..11cabeb 100644 --- a/open_ess/frontend/config.py +++ b/open_ess/frontend/config.py @@ -1,3 +1,5 @@ +from typing import Any + from pydantic import BaseModel, model_validator @@ -10,7 +12,7 @@ class FrontendConfig(BaseModel): @model_validator(mode="before") @classmethod - def set_enable_default(cls, data): + def set_enable_default(cls, data: Any) -> Any: if isinstance(data, dict) and "enable" not in data: data["enable"] = data.get("host") is not None return data diff --git a/open_ess/frontend/dependencies.py b/open_ess/frontend/dependencies.py index da95114..5f2e8c4 100644 --- a/open_ess/frontend/dependencies.py +++ b/open_ess/frontend/dependencies.py @@ -1,6 +1,6 @@ from typing import TYPE_CHECKING -from open_ess.battery_system import BatterySystemConfig, BatterySystem +from open_ess.battery_system import BatterySystem, BatterySystemConfig from open_ess.database import Database, DatabaseConnection from open_ess.pricing import PriceConfig diff --git a/open_ess/frontend/routes/api.py b/open_ess/frontend/routes/api.py index 7e9f5ce..f844233 100644 --- a/open_ess/frontend/routes/api.py +++ b/open_ess/frontend/routes/api.py @@ -1,14 +1,15 @@ import logging -from datetime import datetime, timedelta, timezone -from enum import Enum +from datetime import UTC, datetime, timedelta +from enum import StrEnum from fastapi import APIRouter, Depends, HTTPException, Query from pydantic import BaseModel -from open_ess.battery_system import BatterySystemConfig, BatterySystem +from open_ess.battery_system import BatterySystem, BatterySystemConfig from open_ess.database import DatabaseConnection -from open_ess.frontend.dependencies import get_database, get_price_config, get_battery_systems, get_battery_configs +from open_ess.frontend.dependencies import get_battery_configs, get_battery_systems, get_database, get_price_config from open_ess.pricing import PriceConfig + from .util import TimeSeries, data_to_timeseries, find_full_battery_cycles logger = logging.getLogger(__name__) @@ -31,7 +32,7 @@ class HealthResponse(BaseModel): @router.get("/health", response_model=HealthResponse) -async def health_check(db: DatabaseConnection = Depends(get_database)): +async def health_check(db: DatabaseConnection = Depends(get_database)) -> HealthResponse: try: # TODO: cursor = db.execute("SELECT name FROM sqlite_master WHERE type='table'") @@ -39,7 +40,7 @@ async def health_check(db: DatabaseConnection = Depends(get_database)): return HealthResponse(status="ok", database="connected", tables=tables) except Exception as e: logger.exception("Health check failed") - raise HTTPException(status_code=500, detail=str(e)) + raise HTTPException(status_code=500, detail=str(e)) from e # ---------------------------- # @@ -60,7 +61,7 @@ class SystemLayoutData(BaseModel): @router.get("/system-layout", response_model=SystemLayoutData) -async def get_system_layout(battery_systems: list[BatterySystem] = Depends(get_battery_systems)): +async def get_system_layout(battery_systems: list[BatterySystem] = Depends(get_battery_systems)) -> SystemLayoutData: return SystemLayoutData( phases=[1, 2, 3], # grid_labels=["L1", "L2", "L3"], @@ -86,8 +87,8 @@ class PowerFlowData(BaseModel): @router.get("/power-flow", response_model=PowerFlowData) async def get_power_flow( db: DatabaseConnection = Depends(get_database), battery_systems: list[BatterySystem] = Depends(get_battery_systems) -): - start = datetime.now(timezone.utc) - timedelta(seconds=10) +) -> PowerFlowData: + start = datetime.now(UTC) - timedelta(seconds=10) grid_power = {} for i in (1, 2, 3): @@ -98,7 +99,7 @@ async def get_power_flow( grid_power[f"L{i}"] = power solar_power = None - result = db.get_power(f"victron/pvinverter/31/power/l1", start=start, bucket_seconds=None) + result = db.get_power("victron/pvinverter/31/power/l1", start=start, bucket_seconds=None) if result: _, solar_power = result[-1] @@ -142,7 +143,7 @@ async def get_power_flow( # ------------------------------- # -class Status(str, Enum): +class Status(StrEnum): OK = "ok" WARNING = "warning" ERROR = "error" @@ -165,7 +166,7 @@ class ServicesStatusResponse(BaseModel): @router.get("/services-status", response_model=ServicesStatusResponse) -async def services_status(db: DatabaseConnection = Depends(get_database)): +async def services_status() -> ServicesStatusResponse: try: return ServicesStatusResponse( database=ServiceStatus(status=Status.OK, messages=[]), @@ -173,16 +174,16 @@ async def services_status(db: DatabaseConnection = Depends(get_database)): ) except Exception as e: logger.exception("Health check failed") - raise HTTPException(status_code=500, detail=str(e)) + raise HTTPException(status_code=500, detail=str(e)) from e @router.get("/battery-ids", response_model=list[str]) -async def get_battery_ids(battery_configs: dict[str, BatterySystemConfig] = Depends(get_battery_configs)): +async def get_battery_ids(battery_configs: dict[str, BatterySystemConfig] = Depends(get_battery_configs)) -> list[str]: try: return list(battery_configs.keys()) except Exception as e: logger.exception("Failed to get battery ids") - raise HTTPException(status_code=500, detail=str(e)) + raise HTTPException(status_code=500, detail=str(e)) from e # ------------------------ # @@ -220,13 +221,10 @@ async def get_energy_flow_endpoint( bucket_minutes: int = Query(default=60), db: DatabaseConnection = Depends(get_database), battery_configs: dict[str, BatterySystemConfig] = Depends(get_battery_configs), -): +) -> EnergyGraphResponse: try: - if battery_id is None: - battery_config = battery_configs["victron/vebus/228"] # TODO - else: - battery_config = battery_configs[battery_id] - now = datetime.now(timezone.utc) + battery_config = battery_configs["victron/vebus/228"] if battery_id is None else battery_configs[battery_id] + now = datetime.now(UTC) if start is None: start = now - timedelta(hours=24) if end is None: @@ -289,7 +287,7 @@ async def get_energy_flow_endpoint( ) except Exception as e: logger.exception("Failed to get energy flow") - raise HTTPException(status_code=500, detail=str(e)) + raise HTTPException(status_code=500, detail=str(e)) from e @router.get("/power-graph", response_model=PowerResponse) @@ -300,13 +298,10 @@ async def get_power_graph( aggregate_minutes: int = Query(default=1), db: DatabaseConnection = Depends(get_database), battery_configs: dict[str, BatterySystemConfig] = Depends(get_battery_configs), -): +) -> PowerResponse: try: - if battery_id is None: - battery_config = battery_configs["victron/vebus/228"] # TODO - else: - battery_config = battery_configs[battery_id] - now = datetime.now(timezone.utc) + battery_config = battery_configs["victron/vebus/228"] if battery_id is None else battery_configs[battery_id] + now = datetime.now(UTC) if start is None: start = now - timedelta(hours=24) if end is None: @@ -331,7 +326,7 @@ async def get_power_graph( return PowerResponse(series={k: data_to_timeseries(v) for k, v in series.items()}) except Exception as e: logger.exception("Failed to get power data") - raise HTTPException(status_code=500, detail=str(e)) + raise HTTPException(status_code=500, detail=str(e)) from e class PricePoint(BaseModel): @@ -356,11 +351,11 @@ async def get_price_data( aggregate_minutes: int | None = Query(default=None), db: DatabaseConnection = Depends(get_database), price_config: PriceConfig = Depends(get_price_config), -): +) -> PricesResponse: try: if area is None: area = price_config.area - now = datetime.now(timezone.utc) + now = datetime.now(UTC) if start is None: start = now - timedelta(days=7) if end is None: @@ -386,7 +381,7 @@ async def get_price_data( ) except Exception as e: logger.exception("Failed to get prices") - raise HTTPException(status_code=500, detail=str(e)) + raise HTTPException(status_code=500, detail=str(e)) from e class BatteryGraphResponse(BaseModel): @@ -402,9 +397,9 @@ async def get_battery_graph( end: datetime | None = Query(default=None), db: DatabaseConnection = Depends(get_database), battery_configs: dict[str, BatterySystemConfig] = Depends(get_battery_configs), -): +) -> dict[str, BatteryGraphResponse]: try: - now = datetime.now(timezone.utc) + now = datetime.now(UTC) if start is None: start = now - timedelta(hours=48) if end is None: @@ -427,7 +422,7 @@ async def get_battery_graph( return result except Exception as e: logger.exception("Failed to get battery SOC") - raise HTTPException(status_code=500, detail=str(e)) + raise HTTPException(status_code=500, detail=str(e)) from e # ---------------# @@ -450,16 +445,15 @@ async def get_efficiency_scatter( limit: int = Query(default=2000), aggregate_minutes: int = Query(default=10), idle_threshold: int = Query(default=5), - balancing_threshold: int = Query(default=100), db: DatabaseConnection = Depends(get_database), -): +) -> list[EfficiencyScatterPoint]: try: ac_in = db.get_power("victron/vebus/228/power/ac_in/l1", bucket_seconds=aggregate_minutes * 60, limit=limit) ac_out = db.get_power("victron/vebus/228/power/ac_out/l1", bucket_seconds=aggregate_minutes * 60, limit=limit) dc = db.get_power("victron/vebus/228/power/battery", bucket_seconds=aggregate_minutes * 60, limit=limit) # dc = db.get_power("victron/battery/225/power/battery", bucket_seconds=aggregate_minutes * 60, limit=limit) - data = {ts: [v_in - v_out, None] for (ts, v_in), (_, v_out) in zip(ac_in, ac_out)} + data = {ts: [v_in - v_out, None] for (ts, v_in), (_, v_out) in zip(ac_in, ac_out, strict=False)} for ts, v in dc: if ts in data: data[ts][1] = v @@ -497,14 +491,14 @@ async def get_efficiency_scatter( return points except Exception as e: logger.exception("Failed to get efficiency scatter data") - raise HTTPException(status_code=500, detail=str(e)) + raise HTTPException(status_code=500, detail=str(e)) from e class BatteryCycle(BaseModel): start_time: datetime end_time: datetime duration_hours: float - min_soc: int + min_soc: float ac_energy_in: float | None ac_energy_out: float | None dc_energy_in: float @@ -526,13 +520,10 @@ async def get_battery_cycles( db: DatabaseConnection = Depends(get_database), battery_configs: dict[str, BatterySystemConfig] = Depends(get_battery_configs), price_config: PriceConfig = Depends(get_price_config), -): +) -> list[BatteryCycle]: try: - if battery_id is None: - battery_config = battery_configs["victron/vebus/228"] # TODO - else: - battery_config = battery_configs[battery_id] - now = datetime.now(timezone.utc) + battery_config = battery_configs["victron/vebus/228"] if battery_id is None else battery_configs[battery_id] + now = datetime.now(UTC) if start is None: start = now - timedelta(days=30) if end is None: @@ -623,7 +614,7 @@ async def get_battery_cycles( return cycles except Exception as e: logger.exception("Failed to get battery cycles") - raise HTTPException(status_code=500, detail=str(e)) + raise HTTPException(status_code=500, detail=str(e)) from e # -------------------------# @@ -638,9 +629,9 @@ async def get_power( end: datetime | None = Query(default=None), aggregate_minutes: int = Query(default=1), db: DatabaseConnection = Depends(get_database), -): +) -> PowerResponse: try: - now = datetime.now(timezone.utc) + now = datetime.now(UTC) if start is None: start = now - timedelta(hours=24) if end is None: @@ -649,7 +640,7 @@ async def get_power( return PowerResponse(series={k: data_to_timeseries(v) for k, v in series.items()}) except Exception as e: logger.exception("Failed to get debug power flows") - raise HTTPException(status_code=500, detail=str(e)) + raise HTTPException(status_code=500, detail=str(e)) from e # TODO: add parameter to select subset of series @@ -659,9 +650,9 @@ async def get_energy( start: datetime | None = Query(default=None), end: datetime | None = Query(default=None), db: DatabaseConnection = Depends(get_database), -): +) -> EnergyResponse: try: - now = datetime.now(timezone.utc) + now = datetime.now(UTC) if start is None: start = now - timedelta(hours=24) if end is None: @@ -677,4 +668,4 @@ async def get_energy( return EnergyResponse(series={k: data_to_timeseries(v) for k, v in series.items()}) except Exception as e: logger.exception("Failed to get debug energy flows") - raise HTTPException(status_code=500, detail=str(e)) + raise HTTPException(status_code=500, detail=str(e)) from e diff --git a/open_ess/frontend/routes/pages.py b/open_ess/frontend/routes/pages.py index be14ebb..c4abfbf 100644 --- a/open_ess/frontend/routes/pages.py +++ b/open_ess/frontend/routes/pages.py @@ -12,35 +12,35 @@ @router.get("/favicon.ico", include_in_schema=False) -async def favicon(): +async def favicon() -> FileResponse: return FileResponse(STATIC_DIR / "images/openess-16x16.png") @router.get("/logo-32x32.png", include_in_schema=False) -async def favicon(): +async def logo_32x32() -> FileResponse: return FileResponse(STATIC_DIR / "images/openess-32x32.png") @router.get("/", response_class=HTMLResponse) -async def dashboard(request: Request): +async def dashboard(request: Request) -> HTMLResponse: return templates.TemplateResponse(request, "dashboard.html") @router.get("/metrics", response_class=HTMLResponse) -async def metrics_page(request: Request): +async def metrics_page(request: Request) -> HTMLResponse: return templates.TemplateResponse(request, "metrics.html") @router.get("/cycles", response_class=HTMLResponse) -async def cycles_page(request: Request): +async def cycles_page(request: Request) -> HTMLResponse: return templates.TemplateResponse(request, "cycles.html") @router.get("/debug", response_class=HTMLResponse) -async def debug_page(request: Request): +async def debug_page(request: Request) -> HTMLResponse: return templates.TemplateResponse(request, "debug.html") @router.get("/settings", response_class=HTMLResponse) -async def settings_page(request: Request): +async def settings_page(request: Request) -> HTMLResponse: return templates.TemplateResponse(request, "settings.html") diff --git a/open_ess/frontend/routes/util.py b/open_ess/frontend/routes/util.py index ae6eb5f..204e038 100644 --- a/open_ess/frontend/routes/util.py +++ b/open_ess/frontend/routes/util.py @@ -1,5 +1,5 @@ +from collections.abc import Iterable from datetime import datetime -from typing import Iterable from pydantic import BaseModel @@ -9,7 +9,7 @@ class TimeSeries(BaseModel): values: list[float] -def data_to_timeseries(data: Iterable[tuple[datetime, float]], rounding: int = None) -> TimeSeries: +def data_to_timeseries(data: Iterable[tuple[datetime, float]], rounding: int | None = None) -> TimeSeries: timestamps = [] values = [] for t, v in data: @@ -22,15 +22,15 @@ def data_to_timeseries(data: Iterable[tuple[datetime, float]], rounding: int = N def find_full_battery_cycles( - battery_soc: list[tuple[datetime, float]], full_threshold=99, min_soc_swing=10 -) -> list[tuple[datetime, datetime, int]]: + battery_soc: list[tuple[datetime, float]], full_threshold: int = 99, min_soc_swing: int = 10 +) -> list[tuple[datetime, datetime, float]]: """Very simple algorithm to find battery cycles from full -> lower -> full. Only the start and end timestamps and min SoC for the cycles are returned. """ cycles = [] soc_start_ts = None soc_swing_reached = False - min_soc = 100 + min_soc: float = 100 for timestamp, soc in battery_soc: min_soc = min(soc, min_soc) if soc >= full_threshold: @@ -95,11 +95,11 @@ def find_battery_cycles( "min_soc": min_soc, } - return ( - find_battery_cycles(rows, start, left_peak_idx, min_soc_swing) - + [cycle] - + find_battery_cycles(rows, right_peak_idx + 1, end, min_soc_swing) - ) + return [ + *find_battery_cycles(rows, start, left_peak_idx, min_soc_swing), + cycle, + *find_battery_cycles(rows, right_peak_idx + 1, end, min_soc_swing), + ] else: return find_battery_cycles(rows, start, left_peak_idx, min_soc_swing) + find_battery_cycles( rows, right_peak_idx + 1, end, min_soc_swing diff --git a/open_ess/main.py b/open_ess/main.py index 8c81b02..97482e6 100644 --- a/open_ess/main.py +++ b/open_ess/main.py @@ -3,21 +3,21 @@ import uvicorn -from open_ess.battery_system import VictronBatterySystem +from open_ess.battery_system import BatterySystem, VictronBatterySystem from open_ess.config import Config from open_ess.database import Database, DatabaseService -from open_ess.frontend import init_dependencies, create_app, close_dependencies +from open_ess.frontend import close_dependencies, create_app, init_dependencies from open_ess.optimizer import OptimizerService from open_ess.pricing import EntsoeService -from open_ess.service import Service, ServiceManager -from open_ess.util import setup_logging, parse_args, EndpointFilter +from open_ess.service import ServiceManager +from open_ess.util import EndpointFilter, parse_args, setup_logging from open_ess.victron_modbus import VictronService setup_logging() logger = logging.getLogger(__name__) -def main(): +def main() -> None: args = parse_args("Open Energy Storage System - optimize charging based on day-ahead prices") config = Config.from_file(args.config) @@ -28,7 +28,7 @@ def main(): service_manager = ServiceManager() service_manager.register_service(DatabaseService(database)) service_manager.register_service(EntsoeService(database, config.prices)) - battery_systems = [] + battery_systems: list[BatterySystem] = [] for battery_config in config.battery_systems: if battery_config.is_victron: victron_service = VictronService(database, battery_config) @@ -45,7 +45,7 @@ def main(): ) # Shutdown handler - def shutdown(signum, frame): + def shutdown(signum: int, frame: object) -> None: logger.info("Shutting down...") service_manager.stop() @@ -65,7 +65,7 @@ def shutdown(signum, frame): app = create_app() uvicorn.run( app, - host=config.frontend.host, + host=config.frontend.host, # type: ignore[arg-type] port=config.frontend.port, log_level="info", ) diff --git a/open_ess/optimizer/optimizer.py b/open_ess/optimizer/optimizer.py index 9b4077a..ce9195b 100644 --- a/open_ess/optimizer/optimizer.py +++ b/open_ess/optimizer/optimizer.py @@ -1,13 +1,15 @@ import logging import os from collections.abc import Callable -from datetime import datetime, timedelta, timezone +from datetime import UTC, datetime, timedelta +from typing import Any +import pyomo.core import pyomo.environ as pyo from pyomo.opt import SolverFactory -from open_ess.database import DatabaseConnection from open_ess.battery_system import BatterySystemConfig +from open_ess.database import DatabaseConnection from open_ess.pricing import PriceConfig logger = logging.getLogger(__name__) @@ -34,35 +36,6 @@ def __init__(self, db: DatabaseConnection, price_config: PriceConfig, battery_co def battery_config(self) -> BatterySystemConfig: return self._battery_config - def _soc_balance_rule(self, model, t): - if t == 0: - prev_soc = model.current_soc - else: - prev_soc = model.soc[t - 1] - - # Energy into battery = charge_power - charger_loss - # Energy out of battery = discharge_power + inverter_loss - net_energy = ( - ((model.charge_power[t] - model.charger_loss[t]) - (model.discharge_power[t] + model.inverter_loss[t])) - * model.duration[t] - / 3600 - ) - soc_change = 100 * net_energy / self._battery_config.capacity_kwh - return model.soc[t] == prev_soc + soc_change - - def _objective_rule(self, model): - total = 0 - for t in model.T: - price = model.market_price[t] - # Cost to charge (grid power = charge + charger_loss) - grid_charge = model.charge_power[t] + model.charger_loss[t] - buy_cost = grid_charge * self._price_config.buy_price(price) - # Revenue from discharge (grid power = discharge - inverter_loss... - # but inverter_loss is drawn from battery, so grid gets discharge_power) - sell_revenue = model.discharge_power[t] * self._price_config.sell_price(price) - total += (buy_cost - sell_revenue) * self._price_config.aggregate_minutes / 60 - return total - def optimize(self) -> list[tuple[datetime, datetime, int, float]]: """Generate optimal charge schedule using mixed-integer linear programming. @@ -73,7 +46,7 @@ def optimize(self) -> list[tuple[datetime, datetime, int, float]]: - T: time variable starting at t=0 and going up to t=len(future_prices)-1. """ # Get hourly prices for the planning horizon - now = datetime.now(timezone.utc) + now = datetime.now(UTC) start_hour = now.replace(minute=0, second=0, microsecond=0) prices = self._database.get_prices( self._price_config.area, @@ -107,6 +80,8 @@ def optimize(self) -> list[tuple[datetime, datetime, int, float]]: return [] # Build piecewise linear breakpoints for loss functions + assert self._battery_config.max_charge_power_kw is not None + assert self._battery_config.max_invert_power_kw is not None charger_bp, charger_loss_vals = build_piecewise_loss_points( self._battery_config.max_charge_power_kw, charger_loss ) @@ -155,11 +130,8 @@ def optimize(self) -> list[tuple[datetime, datetime, int, float]]: ) # SOC dynamics constraint - def soc_balance_rule(model, t): - if t == 0: - prev_soc = current_soc - else: - prev_soc = model.soc[t - 1] + def soc_balance_rule(model: pyomo.core.Model, t: int) -> Any: + prev_soc = current_soc if t == 0 else model.soc[t - 1] # Energy into battery = charge_power - charger_loss # Energy out of battery = discharge_power + inverter_loss @@ -177,7 +149,7 @@ def soc_balance_rule(model, t): # ^ Final SoC should equal starting SOC (energy neutral over horizon) # Objective: minimize cost (buy cost - sell revenue) - def objective_rule(model): + def objective_rule(model: pyomo.core.Model) -> float: total = 0 for t in model.T: price = model.market_price[t] @@ -257,8 +229,8 @@ def predict_next_week( delta = last_time - first_time start_of_week = last_time - timedelta(weeks=delta.days // 7) - weeks = [] - week = [] + weeks: list[list[tuple[datetime, float]]] = [] + week: list[tuple[datetime, float]] = [] for t, p in prices: if t > start_of_week: start_of_week += timedelta(weeks=1) @@ -274,7 +246,7 @@ def predict_next_week( for week in weeks: week_avg = sum(p for _, p in week) / len(week) factor = last_week_avg / week_avg - for i, (t, p) in enumerate(week): + for i, (_t, p) in enumerate(week): next_week[i] = (next_week[i][0], next_week[i][1] + p * factor) for i, (t, p) in enumerate(next_week): next_week[i] = (t, p / len(weeks)) diff --git a/open_ess/optimizer/service.py b/open_ess/optimizer/service.py index 2ffc28c..136bdc3 100644 --- a/open_ess/optimizer/service.py +++ b/open_ess/optimizer/service.py @@ -1,10 +1,11 @@ import logging -from datetime import datetime, timedelta, timezone +from datetime import UTC, datetime, timedelta from open_ess.battery_system import BatterySystem from open_ess.database import Database, DatabaseConnection from open_ess.pricing import PriceConfig from open_ess.service import Service + from .optimizer import Optimizer logger = logging.getLogger(__name__) @@ -25,26 +26,29 @@ def __init__( self._db_conn: DatabaseConnection | None = None self._optimizer: Optimizer | None = None - def on_start(self): + def on_start(self) -> None: self._db_conn = self._db.connect() self._optimizer = Optimizer( self._db_conn, price_config=self._price_config, battery_config=self._battery_system.config ) - def tick(self): + def tick(self) -> None: + if self._optimizer is None or self._db_conn is None: + return + logger.debug("Running charge optimizer(s)") schedule = self._optimizer.optimize() if schedule: _, _, power, _ = schedule[0] self._battery_system.set_ess_setpoint(power) - self._db_conn.set_schedule(self._battery_system.id, schedule) + self._db_conn.set_schedule(self._battery_system.id, schedule) # type: ignore[arg-type] logger.debug(f"Updated schedule with {len(schedule)} entries") else: logger.warning("Optimizer returned empty schedule") - def wait_until_next(self): + def wait_until_next(self) -> None: """Wait until the start of the next price bracket.""" - now = datetime.now(timezone.utc) + now = datetime.now(UTC) next_run = now.replace( minute=(now.minute // self._price_config.aggregate_minutes) * self._price_config.aggregate_minutes, second=0, diff --git a/open_ess/pricing/client.py b/open_ess/pricing/client.py index a44827a..edf0031 100644 --- a/open_ess/pricing/client.py +++ b/open_ess/pricing/client.py @@ -1,5 +1,5 @@ import logging -from datetime import datetime, timedelta, timezone +from datetime import UTC, datetime, timedelta from xml.etree import ElementTree as ET from zoneinfo import ZoneInfo @@ -9,6 +9,7 @@ from pandas import DataFrame from open_ess.database import DatabaseConnection + from .areas import AREAS from .config import PriceConfig @@ -78,8 +79,8 @@ def fetch_day_ahead_prices( prices.append((row_start, row_end, price)) return prices - def fetch_missing_prices(self): - now = datetime.now(timezone.utc) + def fetch_missing_prices(self) -> None: + now = datetime.now(UTC) end_of_tomorrow = (now + timedelta(days=2)).replace(hour=0, minute=0, second=0, microsecond=0) latest = self._db.get_latest_price_time(self._config.area) diff --git a/open_ess/pricing/config.py b/open_ess/pricing/config.py index b2522cc..adea1c0 100644 --- a/open_ess/pricing/config.py +++ b/open_ess/pricing/config.py @@ -2,12 +2,12 @@ import logging import os +from collections.abc import Callable from pathlib import Path -from typing import Callable from pydantic import BaseModel, model_validator -from .formula import compile_formula, FormulaError +from .formula import FormulaError, compile_formula logger = logging.getLogger(__name__) @@ -69,12 +69,12 @@ def resolve_and_validate(self) -> "PriceConfig": try: self._buy_fn = compile_formula(self.buy_formula) except FormulaError as e: - raise ValueError(f"Invalid buy_formula: {e}") + raise ValueError(f"Invalid buy_formula: {e}") from e try: self._sell_fn = compile_formula(self.sell_formula) except FormulaError as e: - raise ValueError(f"Invalid sell_formula: {e}") + raise ValueError(f"Invalid sell_formula: {e}") from e return self diff --git a/open_ess/pricing/formula.py b/open_ess/pricing/formula.py index 679c6c3..4d4c114 100644 --- a/open_ess/pricing/formula.py +++ b/open_ess/pricing/formula.py @@ -6,7 +6,8 @@ import ast import operator -from typing import Callable +from collections.abc import Callable +from typing import cast # Allowed binary operators BINARY_OPS = { @@ -53,14 +54,14 @@ def _eval_node(node: ast.AST, price: float) -> float: raise FormulaError(f"Operator {op_type.__name__} not allowed") left = _eval_node(node.left, price) right = _eval_node(node.right, price) - return BINARY_OPS[op_type](left, right) + return cast(float, BINARY_OPS[op_type](left, right)) elif isinstance(node, ast.UnaryOp): - op_type = type(node.op) - if op_type not in UNARY_OPS: - raise FormulaError(f"Unary operator {op_type.__name__} not allowed") + unary_op_type = type(node.op) + if unary_op_type not in UNARY_OPS: + raise FormulaError(f"Unary operator {unary_op_type.__name__} not allowed") operand = _eval_node(node.operand, price) - return UNARY_OPS[op_type](operand) + return cast(float, UNARY_OPS[unary_op_type](operand)) # type: ignore[operator] else: raise FormulaError(f"Expression type {type(node).__name__} not allowed") @@ -89,7 +90,7 @@ def compile_formula(formula: str) -> Callable[[float], float]: try: tree = ast.parse(formula, mode="eval") except SyntaxError as e: - raise FormulaError(f"Invalid formula syntax: {e}") + raise FormulaError(f"Invalid formula syntax: {e}") from e # Validate the tree before returning the evaluator def validate(node: ast.AST) -> None: @@ -97,18 +98,18 @@ def validate(node: ast.AST) -> None: validate(node.body) elif isinstance(node, ast.Constant): if not isinstance(node.value, (int, float)): - raise FormulaError(f"Only numeric constants allowed") + raise FormulaError("Only numeric constants allowed") elif isinstance(node, ast.Name): if node.id not in ("price", "p"): raise FormulaError(f"Unknown variable '{node.id}'") elif isinstance(node, ast.BinOp): if type(node.op) not in BINARY_OPS: - raise FormulaError(f"Operator not allowed") + raise FormulaError("Operator not allowed") validate(node.left) validate(node.right) elif isinstance(node, ast.UnaryOp): if type(node.op) not in UNARY_OPS: - raise FormulaError(f"Unary operator not allowed") + raise FormulaError("Unary operator not allowed") validate(node.operand) else: raise FormulaError(f"Expression type {type(node).__name__} not allowed") diff --git a/open_ess/pricing/service.py b/open_ess/pricing/service.py index 8384d1d..bcb9b10 100644 --- a/open_ess/pricing/service.py +++ b/open_ess/pricing/service.py @@ -2,6 +2,7 @@ from open_ess.database import Database, DatabaseConnection from open_ess.service import Service + from .client import EntsoeClient from .config import PriceConfig @@ -19,20 +20,22 @@ def __init__(self, db: Database, config: PriceConfig): self._client: EntsoeClient | None = None self._db_conn: DatabaseConnection | None = None - def on_start(self): + def on_start(self) -> None: self._db_conn = self._db.connect() self._client = EntsoeClient(self._config, self._db_conn) self._fetch_prices() - def tick(self): + def tick(self) -> None: self._fetch_prices() - def _fetch_prices(self): + def _fetch_prices(self) -> None: + if self._client is None: + return None try: self._client.fetch_missing_prices() except Exception as e: logger.error(f"Failed to fetch ENTSO-E prices: {e}") - def wait_until_next(self): + def wait_until_next(self) -> None: # TODO: run from 14:00 self.wait_seconds(self._check_interval) diff --git a/open_ess/scripts/generate_types.py b/open_ess/scripts/generate_types.py index 9036ad2..b95a9d1 100644 --- a/open_ess/scripts/generate_types.py +++ b/open_ess/scripts/generate_types.py @@ -14,11 +14,18 @@ from enum import Enum from pathlib import Path from types import NoneType, UnionType -from typing import Any, get_args, get_origin +from typing import Any, TypedDict, get_args, get_origin from fastapi.routing import APIRoute from pydantic import BaseModel + +class _ParamInfo(TypedDict): + name: str + ts_type: str + optional: bool + + logger = logging.getLogger(__name__) @@ -71,7 +78,7 @@ def python_type_to_ts(python_type: Any, models: dict[str, type]) -> str: # Fallback if hasattr(python_type, "__name__"): - return python_type.__name__ + return str(python_type.__name__) return "unknown" @@ -99,12 +106,12 @@ def generate_interface_ts(model: type[BaseModel], models: dict[str, type]) -> st return "\n".join(lines) -def collect_models(module) -> tuple[list[type[Enum]], list[type[BaseModel]]]: +def collect_models(module: object) -> tuple[list[type[Enum]], list[type[BaseModel]]]: """Collect all Enum and BaseModel classes from a module.""" enums = [] models = [] - for name, obj in inspect.getmembers(module): + for _name, obj in inspect.getmembers(module): if inspect.isclass(obj): if issubclass(obj, Enum) and obj is not Enum: enums.append(obj) @@ -150,7 +157,7 @@ def generate_api_function(route: APIRoute, models_dict: dict[str, type]) -> str response_type = python_type_to_ts(route.response_model, models_dict) # Extract query parameters from the endpoint function signature - params = [] + params: list[_ParamInfo] = [] endpoint = route.endpoint sig = inspect.signature(endpoint) @@ -183,9 +190,9 @@ def generate_api_function(route: APIRoute, models_dict: dict[str, type]) -> str param_strs = [] # Sort so required params come first params.sort(key=lambda x: x["optional"]) - for param in params: - opt = "?" if param["optional"] else "" - param_strs.append(f'{param["name"]}{opt}: {param["ts_type"]}') + for p in params: + opt = "?" if p["optional"] else "" + param_strs.append(f"{p['name']}{opt}: {p['ts_type']}") params_signature = "params: { " + "; ".join(param_strs) + " }" else: params_signature = "" @@ -196,9 +203,9 @@ def generate_api_function(route: APIRoute, models_dict: dict[str, type]) -> str if params: lines.append("const searchParams = new URLSearchParams();") - for param in params: + for p in params: lines.append( - f"if (params.{param['name']} !== undefined) searchParams.set('{param['name']}', String(params.{param['name']}));" + f"if (params.{p['name']} !== undefined) searchParams.set('{p['name']}', String(params.{p['name']}));" ) lines.append("const query = searchParams.toString() ? `?${searchParams.toString()}` : '';") lines.append(f"const response = await fetch(`/api{path}${{query}}`);") @@ -274,7 +281,7 @@ def generate_types_file(output_path: Path) -> None: print(f"Generated {output_path}") -def main(): +def main() -> None: output_path = Path("open_ess/frontend/src/types.ts") try: generate_types_file(output_path) diff --git a/open_ess/service.py b/open_ess/service.py index ababb51..00ffa0e 100644 --- a/open_ess/service.py +++ b/open_ess/service.py @@ -21,7 +21,7 @@ def running(self) -> bool: def is_ready(self) -> bool: return self._ready - def run(self): + def run(self) -> None: """Thread entry point.""" self._running = True logger.info(f"{self.name} started") @@ -47,20 +47,20 @@ def run(self): logger.info(f"{self.name} stopped") - def on_start(self): + def on_start(self) -> None: """Called once when service starts. Override for initialization.""" pass @abstractmethod - def tick(self): + def tick(self) -> None: """Called repeatedly. Override with service logic.""" pass - def wait_until_next(self): + def wait_until_next(self) -> None: """Wait until next tick. Override for custom timing.""" self._stop_event.wait(timeout=1.0) - def stop(self): + def stop(self) -> None: """Signal the service to stop.""" self._running = False self._stop_event.set() @@ -71,19 +71,19 @@ def wait_seconds(self, seconds: float) -> bool: class ServiceManager: - def __init__(self): + def __init__(self) -> None: self._services: list[Service] = [] self._dependencies: dict[Service, list[Service]] = {} self._running = False - def register_service(self, service: Service, requires: Service | list[Service] = None): + def register_service(self, service: Service, requires: Service | list[Service] | None = None) -> None: self._services.append(service) if requires: if not isinstance(requires, list): requires = [requires] self._dependencies[service] = requires - def start(self): + def start(self) -> None: self._running = True services_to_start = self._services services_on_hold = [] @@ -97,11 +97,11 @@ def start(self): services_on_hold = [] time.sleep(0.1) - def stop(self): + def stop(self) -> None: self._running = False for service in self._services: service.stop() - def wait_for_stop(self): + def wait_for_stop(self) -> None: for service in self._services: service.join() diff --git a/open_ess/util.py b/open_ess/util.py index ef4b0c1..72b34f2 100644 --- a/open_ess/util.py +++ b/open_ess/util.py @@ -1,11 +1,7 @@ import argparse import logging -from datetime import datetime, timedelta, timezone from pathlib import Path - -import matplotlib.pyplot as plt - -from open_ess.database import DatabaseConnection +from typing import ClassVar logger = logging.getLogger(__name__) @@ -13,7 +9,7 @@ class ColoredFormatter(logging.Formatter): RESET = "\033[0m" - LEVEL_COLORS = { + LEVEL_COLORS: ClassVar[dict] = { logging.DEBUG: "\033[36m", # cyan logging.INFO: "\033[32m", # green logging.WARNING: "\033[33m", # yellow @@ -21,7 +17,7 @@ class ColoredFormatter(logging.Formatter): logging.CRITICAL: "\033[1;91m", # bold red } - def format(self, record): + def format(self, record: logging.LogRecord) -> str: timestamp = self.formatTime(record, "%Y-%m-%d %H:%M:%S") msecs = f"{record.msecs:03.0f}" time_str = f"\033[90m{timestamp}.{msecs}{self.RESET}" # grey @@ -48,7 +44,7 @@ def format(self, record): return result -def setup_logging(): +def setup_logging() -> None: handler = logging.StreamHandler() handler.setFormatter(ColoredFormatter()) logging.root.addHandler(handler) @@ -74,44 +70,3 @@ def parse_args(description: str) -> argparse.Namespace: help="Path to config file (YAML)", ) return parser.parse_args() - - -def plot_energy_prices(db: DatabaseConnection, area: str): - now = datetime.now(timezone.utc) - start = now - timedelta(days=28) - end = now + timedelta(days=2) - - prices = db.get_prices(area, start, end) - if not prices: - logger.warning(f"No prices found for {area} between {start} and {end}") - return - - # Group prices by week (Monday-based) - weeks: dict[tuple[int, int], tuple[list[float], list[float]]] = {} - for start_time, _, price in prices: - # Find the Monday of this week - days_since_monday = start_time.weekday() - week_start = (start_time - timedelta(days=days_since_monday)).replace(hour=0, minute=0, second=0, microsecond=0) - iso_year, iso_week, _ = week_start.isocalendar() - week_key = (iso_year, iso_week) - - # Hours since Monday 00:00 - hours_offset = (start_time - week_start).total_seconds() / 3600 - - if week_key not in weeks: - weeks[week_key] = ([], []) - weeks[week_key][0].append(hours_offset) - weeks[week_key][1].append(price) - - plt.figure(figsize=(12, 6)) - for (year, week), (hours, values) in sorted(weeks.items()): - plt.step(hours, values, where="post", label=f"{year} W{week}") - - day_names = ["Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun"] - plt.xticks(ticks=[i * 24 for i in range(7)], labels=day_names) - plt.ylabel("Price (EUR/MWh)") - plt.title(f"Day-Ahead Energy Prices - {area}") - plt.legend() - plt.grid(True, alpha=0.3) - plt.tight_layout() - plt.show() diff --git a/open_ess/victron_modbus/__init__.py b/open_ess/victron_modbus/__init__.py index 5c01048..a65f9f4 100644 --- a/open_ess/victron_modbus/__init__.py +++ b/open_ess/victron_modbus/__init__.py @@ -1,6 +1,6 @@ from .client import VictronClient from .config import VictronConfig -from .registers import Register, Battery, DataType, GridMeter, SolarInverter, System, VEBus +from .registers import Battery, DataType, GridMeter, Register, SolarInverter, System, VEBus from .service import VictronService __all__ = [ diff --git a/open_ess/victron_modbus/client.py b/open_ess/victron_modbus/client.py index b91b8ef..bfaf303 100644 --- a/open_ess/victron_modbus/client.py +++ b/open_ess/victron_modbus/client.py @@ -1,11 +1,13 @@ import logging -from datetime import datetime, timezone +from datetime import UTC, datetime from threading import Lock from typing import TYPE_CHECKING from open_ess.database import Database, DatabaseConnection + +from .config import VictronConfig from .modbus_client import VictronModbusClient -from .registers import Register, System, VEBus, GridMeter, Battery, SolarInverter +from .registers import Battery, GridMeter, Register, SolarInverter, System, VEBus if TYPE_CHECKING: from open_ess.battery_system import BatterySystemConfig @@ -13,16 +15,25 @@ logger = logging.getLogger(__name__) +def _get_float(values: dict[Register, float | bytes | None], key: Register) -> float | None: + """Extract a float value from read_many results, filtering out bytes and None.""" + val = values.get(key) + return val if isinstance(val, (int, float)) else None + + class VictronClient: def __init__(self, database: Database, config: "BatterySystemConfig"): + if not isinstance(config.control, VictronConfig): + raise TypeError(f"VictronClient requires VictronConfig, got {type(config.control).__name__}") self._db = database self._config = config - self._client = VictronModbusClient(config.control) + self._control: VictronConfig = config.control + self._client = VictronModbusClient(self._control) self._db_conn: DatabaseConnection | None = None self._serial: str | None = None - self._setpoint: float | None = None # In Watt + self._setpoint: float = 0.0 # In Watt self._setpoint_expiration: datetime | None = None self._lock = Lock() @@ -32,7 +43,9 @@ def initialize(self) -> bool: if not self.connect(): return False - self._serial = self.read(self.system_id, System.SERIAL).decode("utf-8") + serial_bytes = self.read(self.system_id, System.SERIAL) + if isinstance(serial_bytes, bytes): + self._serial = serial_bytes.decode("utf-8") # Enable ESS mode 3 (external control) if not self._config.monitor_only: @@ -50,34 +63,37 @@ def serial(self) -> str | None: @property def system_id(self) -> int: - return self._config.control.system_id + return self._control.system_id @property def vebus_id(self) -> int: - return self._config.control.vebus_id + return self._control.vebus_id @property def battery_id(self) -> int | None: - return self._config.control.battery_id + return self._control.battery_id @property def grid_id(self) -> int | None: - return self._config.control.grid_id + return self._control.grid_id @property def pvinverter_id(self) -> int | None: - return self._config.control.pvinverter_id + return self._control.pvinverter_id @property def need_mode_3(self) -> bool: return not self._config.monitor_only - def set_ess_setpoint(self, power: float, until: datetime): + def set_ess_setpoint(self, power: float, until: datetime) -> None: with self._lock: self._setpoint = power self._setpoint_expiration = until - def write_setpoints(self): + def write_setpoints(self) -> None: + if self._db_conn is None: + return + if self._config.monitor_only: return @@ -85,20 +101,20 @@ def write_setpoints(self): if ess_mode != 3: raise ValueError("Someone disabled ESS mode 3! Is VRM still managing the system?") - now = datetime.now(tz=timezone.utc) + now = datetime.now(tz=UTC) with self._lock: if self._setpoint_expiration is None or now >= self._setpoint_expiration: - self._setpoint = None + self._setpoint = 0.0 self._setpoint_expiration = None if self._setpoint is None: return idle_threshold = self._config.idle_threshold_w / 1000 - if self._db_conn.get_current_soc() >= 99 and self._setpoint >= -idle_threshold: + if (self._db_conn.get_current_soc() or 0) >= 99 and self._setpoint >= -idle_threshold: # Keep putting power into the battery to allow balancing of the cells by the BMS. # TODO: implement balancing limits? - self.write(self.vebus_id, VEBus.ESS_SETPOINT_L1, int(self._config.max_charge_power_kw * 1000)) + self.write(self.vebus_id, VEBus.ESS_SETPOINT_L1, int((self._config.max_charge_power_kw or 0) * 1000)) self.write(self.vebus_id, VEBus.ESS_DISABLE_CHARGE, 0) else: if abs(self._setpoint) >= idle_threshold: @@ -107,20 +123,22 @@ def write_setpoints(self): self.write(self.vebus_id, VEBus.ESS_DISABLE_FEEDBACK, 0) else: self.write(self.vebus_id, VEBus.ESS_SETPOINT_L1, 0) - if self._config.control.disable_charger_when_idle: + if self._control.disable_charger_when_idle: self.write(self.vebus_id, VEBus.ESS_DISABLE_CHARGE, 1) - if self._config.control.disable_inverter_when_idle: + if self._control.disable_inverter_when_idle: self.write(self.vebus_id, VEBus.ESS_DISABLE_FEEDBACK, 1) def collect_and_store_measurements(self) -> None: - timestamp = datetime.now(timezone.utc) + if self._db_conn is None: + return + timestamp = datetime.now(UTC) # Read System registers system_regs = [System.GRID_L1, System.GRID_L2, System.GRID_L3] system_values = self.read_many(self.system_id, system_regs) - self._db_conn.insert_power("grid/power/l1", timestamp, system_values.get(System.GRID_L1)) - self._db_conn.insert_power("grid/power/l2", timestamp, system_values.get(System.GRID_L2)) - self._db_conn.insert_power("grid/power/l3", timestamp, system_values.get(System.GRID_L3)) + self._db_conn.insert_power("grid/power/l1", timestamp, _get_float(system_values, System.GRID_L1)) + self._db_conn.insert_power("grid/power/l2", timestamp, _get_float(system_values, System.GRID_L2)) + self._db_conn.insert_power("grid/power/l3", timestamp, _get_float(system_values, System.GRID_L3)) if self.grid_id: # TODO: check if grid meter delivers data per phase or not @@ -129,10 +147,10 @@ def collect_and_store_measurements(self) -> None: [GridMeter.ENERGY_TO_NET_TOTAL, GridMeter.ENERGY_FROM_NET_TOTAL], ) self._db_conn.insert_energy( - "grid/energy/import/total", timestamp, grid_values.get(GridMeter.ENERGY_FROM_NET_TOTAL) + "grid/energy/import/total", timestamp, _get_float(grid_values, GridMeter.ENERGY_FROM_NET_TOTAL) ) self._db_conn.insert_energy( - "grid/energy/export/total", timestamp, grid_values.get(GridMeter.ENERGY_TO_NET_TOTAL) + "grid/energy/export/total", timestamp, _get_float(grid_values, GridMeter.ENERGY_TO_NET_TOTAL) ) if self.pvinverter_id: @@ -146,12 +164,12 @@ def collect_and_store_measurements(self) -> None: self._db_conn.insert_energy( f"victron/pvinverter/{self.pvinverter_id}/energy/l1", timestamp, - pvinverter_values.get(SolarInverter.ENERGY_L1), + _get_float(pvinverter_values, SolarInverter.ENERGY_L1), ) self._db_conn.insert_power( f"victron/pvinverter/{self.pvinverter_id}/power/l1", timestamp, - pvinverter_values.get(SolarInverter.POWER_L1), + _get_float(pvinverter_values, SolarInverter.POWER_L1), ) # VEBus registers for each device @@ -178,58 +196,59 @@ def collect_and_store_measurements(self) -> None: VEBus.ENERGY_AC_OUT_TO_BATTERY, ] - vebus_prefix = self._config.control.vebus_prefix + vebus_prefix = self._control.vebus_prefix vebus_values = self.read_many(self.vebus_id, vebus_regs) self._db_conn.insert_power( - f"{vebus_prefix}/power/ac_in/l1", timestamp, vebus_values.get(VEBus.AC_INPUT_POWER_L1) + f"{vebus_prefix}/power/ac_in/l1", timestamp, _get_float(vebus_values, VEBus.AC_INPUT_POWER_L1) ) self._db_conn.insert_power( - f"{vebus_prefix}/power/ac_out/l1", timestamp, vebus_values.get(VEBus.AC_OUTPUT_POWER_L1) + f"{vebus_prefix}/power/ac_out/l1", timestamp, _get_float(vebus_values, VEBus.AC_OUTPUT_POWER_L1) ) - soc = vebus_values.get(VEBus.SOC) + soc = _get_float(vebus_values, VEBus.SOC) if soc is not None: self._db_conn.insert_soc(f"{vebus_prefix}/soc", timestamp, int(soc)) - dc_current = vebus_values.get(VEBus.DC_CURRENT) - dc_voltage = vebus_values.get(VEBus.DC_VOLTAGE) - self._db_conn.insert_voltage(f"{vebus_prefix}/voltage/battery", timestamp, dc_voltage) - if dc_current is not None and dc_voltage is not None: - dc_power = dc_current * dc_voltage - self._db_conn.insert_power(f"{vebus_prefix}/power/battery", timestamp, dc_power) + dc_current = _get_float(vebus_values, VEBus.DC_CURRENT) + dc_voltage = _get_float(vebus_values, VEBus.DC_VOLTAGE) + if dc_voltage is not None: + self._db_conn.insert_voltage(f"{vebus_prefix}/voltage/battery", timestamp, dc_voltage) + if dc_current is not None: + dc_power = dc_current * dc_voltage + self._db_conn.insert_power(f"{vebus_prefix}/power/battery", timestamp, dc_power) # Energy flows self._db_conn.insert_energy( f"{vebus_prefix}/energy/ac_in_to_ac_out", timestamp, - vebus_values.get(VEBus.ENERGY_AC_IN1_TO_AC_OUT), + _get_float(vebus_values, VEBus.ENERGY_AC_IN1_TO_AC_OUT), ) self._db_conn.insert_energy( - f"{vebus_prefix}/energy/ac_in_import", timestamp, vebus_values.get(VEBus.ENERGY_AC_IN1_TO_BATTERY) + f"{vebus_prefix}/energy/ac_in_import", timestamp, _get_float(vebus_values, VEBus.ENERGY_AC_IN1_TO_BATTERY) ) - # self._database.insert_energy("", timestamp, vebus_values.get(VEBus.ENERGY_AC_IN2_TO_AC_OUT)) - # self._database.insert_energy("", timestamp, vebus_values.get(VEBus.ENERGY_AC_IN2_TO_BATTERY)) + # self._database.insert_energy("", timestamp, _get_float(vebus_values,VEBus.ENERGY_AC_IN2_TO_AC_OUT)) + # self._database.insert_energy("", timestamp, _get_float(vebus_values,VEBus.ENERGY_AC_IN2_TO_BATTERY)) self._db_conn.insert_energy( f"{vebus_prefix}/energy/ac_out_to_ac_in", timestamp, - vebus_values.get(VEBus.ENERGY_AC_OUT_TO_AC_IN1), + _get_float(vebus_values, VEBus.ENERGY_AC_OUT_TO_AC_IN1), ) - # self._database.insert_energy("", timestamp, vebus_values.get(VEBus.ENERGY_AC_OUT_TO_AC_IN2)) + # self._database.insert_energy("", timestamp, _get_float(vebus_values,VEBus.ENERGY_AC_OUT_TO_AC_IN2)) self._db_conn.insert_energy( - f"{vebus_prefix}/energy/ac_in_export", timestamp, vebus_values.get(VEBus.ENERGY_BATTERY_TO_AC_IN1) + f"{vebus_prefix}/energy/ac_in_export", timestamp, _get_float(vebus_values, VEBus.ENERGY_BATTERY_TO_AC_IN1) ) - # self._database.insert_energy("", timestamp, vebus_values.get(VEBus.ENERGY_BATTERY_TO_AC_IN2)) + # self._database.insert_energy("", timestamp, _get_float(vebus_values,VEBus.ENERGY_BATTERY_TO_AC_IN2)) self._db_conn.insert_energy( - f"{vebus_prefix}/energy/ac_out_export", timestamp, vebus_values.get(VEBus.ENERGY_BATTERY_TO_AC_OUT) + f"{vebus_prefix}/energy/ac_out_export", timestamp, _get_float(vebus_values, VEBus.ENERGY_BATTERY_TO_AC_OUT) ) self._db_conn.insert_energy( - f"{vebus_prefix}/energy/ac_out_import", timestamp, vebus_values.get(VEBus.ENERGY_AC_OUT_TO_BATTERY) + f"{vebus_prefix}/energy/ac_out_import", timestamp, _get_float(vebus_values, VEBus.ENERGY_AC_OUT_TO_BATTERY) ) if self.battery_id is not None: - bms_prefix = self._config.control.battery_prefix + bms_prefix = self._control.battery_prefix bms_values = self.read_many( self.battery_id, @@ -242,9 +261,15 @@ def collect_and_store_measurements(self) -> None: ], ) - self._db_conn.insert_power(f"{bms_prefix}/power/battery", timestamp, bms_values.get(Battery.DC_POWER)) - self._db_conn.insert_voltage(f"{bms_prefix}/voltage/battery", timestamp, bms_values.get(Battery.DC_VOLTAGE)) - self._db_conn.insert_soc(f"{bms_prefix}/soc", timestamp, round(bms_values.get(Battery.SOC))) + self._db_conn.insert_power( + f"{bms_prefix}/power/battery", timestamp, _get_float(bms_values, Battery.DC_POWER) + ) + self._db_conn.insert_voltage( + f"{bms_prefix}/voltage/battery", timestamp, _get_float(bms_values, Battery.DC_VOLTAGE) + ) + bms_soc = _get_float(bms_values, Battery.SOC) + if bms_soc is not None: + self._db_conn.insert_soc(f"{bms_prefix}/soc", timestamp, round(bms_soc)) # --------------------------------# # VictronModbusClient bindings # @@ -259,8 +284,8 @@ def read(self, unit_id: int, register: Register) -> float | bytes | None: def write(self, unit_id: int, register: Register, value: float) -> bool: return self._client.write(unit_id, register, value) - def read_many(self, unit_id: int, registers: list[Register]) -> dict[Register, float | None]: + def read_many(self, unit_id: int, registers: list[Register]) -> dict[Register, float | bytes | None]: return self._client.read_many(unit_id, registers) - def close(self): + def close(self) -> None: self._client.close() diff --git a/open_ess/victron_modbus/modbus_client.py b/open_ess/victron_modbus/modbus_client.py index ac989f4..98945c1 100644 --- a/open_ess/victron_modbus/modbus_client.py +++ b/open_ess/victron_modbus/modbus_client.py @@ -30,9 +30,9 @@ def address(self) -> str: return f"{self.host}:{self.port}" def connect(self) -> bool: - return self._client.connect() + return bool(self._client.connect()) - def close(self): + def close(self) -> None: self._client.close() def read(self, unit_id: int, register: Register) -> float | bytes | None: @@ -138,7 +138,7 @@ def read_many(self, unit_id: int, registers: list[Register]) -> dict[Register, f batches.append(current_batch) # Read each batch - results: dict[Register, float | None] = {} + results: dict[Register, float | bytes | None] = {} for batch in batches: start_addr = batch[0].address diff --git a/open_ess/victron_modbus/registers.py b/open_ess/victron_modbus/registers.py index d8bb67c..e9c2d5e 100644 --- a/open_ess/victron_modbus/registers.py +++ b/open_ess/victron_modbus/registers.py @@ -26,7 +26,7 @@ def __init__(self, register_count: int, signed: bool): self.signed = signed -DataType.STRING = StringType +DataType.STRING = StringType # type: ignore[attr-defined] @dataclass(frozen=True) @@ -67,7 +67,7 @@ def __lt__(self, other: "Register") -> bool: # Aggregated system-level data # ============================================================================= class System: - SERIAL = Register("Serial", 800, DataType.STRING(6)) + SERIAL = Register("Serial", 800, DataType.STRING(6)) # type: ignore[attr-defined] # AC consumption AC_CONSUMPTION_L1 = Register("AC Consumption L1", 817, DataType.UINT16) diff --git a/open_ess/victron_modbus/service.py b/open_ess/victron_modbus/service.py index eca4f99..2dd48c0 100644 --- a/open_ess/victron_modbus/service.py +++ b/open_ess/victron_modbus/service.py @@ -2,8 +2,9 @@ import time from typing import TYPE_CHECKING -from open_ess.database import Database, DatabaseConnection +from open_ess.database import Database from open_ess.service import Service + from .client import VictronClient if TYPE_CHECKING: @@ -24,22 +25,22 @@ def __init__(self, db: Database, config: "BatterySystemConfig"): def client(self) -> VictronClient: return self._client - def on_start(self): + def on_start(self) -> None: if not self._client.initialize(): raise RuntimeError(f"Could not connect to Victron GX at {self._client.address}") logger.info(f"Connected to Victron GX at {self._client.address}") - def tick(self): + def tick(self) -> None: self._client.write_setpoints() self._client.collect_and_store_measurements() - def wait_until_next(self): + def wait_until_next(self) -> None: # Sleep until the start of the next second now = time.time() sleep_duration = 1.0 - (now % 1.0) self._stop_event.wait(timeout=sleep_duration) - def stop(self): + def stop(self) -> None: super().stop() if self._client: self._client.close() diff --git a/pyproject.toml b/pyproject.toml index ded075e..1b94c86 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,8 +11,49 @@ include = ["open_ess*"] [tool.setuptools.package-data] open_ess = ["*.js", "*.html", "*.css"] -[tool.black] +[tool.ruff] line-length = 120 +target-version = "py311" +output-format = "concise" + +[tool.ruff.lint] +select = [ + "E", # pycodestyle errors + "F", # pyflakes (unused imports, undefined names) + "I", # isort (import sorting) + "UP", # pyupgrade (modern Python syntax) + "B", # flake8-bugbear (common bugs) + "SIM", # flake8-simplify + "RUF", # Ruff-specific rules +] +ignore = [ + "E501", # line too long (handled by formatter) + "E741", # ambiguous variable name (such as `l`) + "RUF022", # __all__ sorting - alphabetical is fine +] + +[tool.ruff.lint.isort] +known-first-party = ["open_ess"] + +[tool.ruff.lint.per-file-ignores] +"open_ess/frontend/routes/api.py" = ["B008"] # Query() in defaults is standard FastAPI pattern + +[tool.ruff.format] +quote-style = "double" + +[tool.mypy] +python_version = "3.11" +warn_return_any = true +warn_unused_ignores = true +disallow_untyped_defs = true + +[[tool.mypy.overrides]] +module = ["entsoe", "entsoe.*", "pyomo", "pyomo.*"] +ignore_missing_imports = true + +[[tool.mypy.overrides]] +module = "open_ess.frontend.routes.api" +ignore_errors = true [project] @@ -20,6 +61,7 @@ name = "open-ess" description = "Open Energy Storage System - Charge/discharge schedule optimizer for day-ahead energy prices." version = "0.0.0" authors = [{ name = "David van 't Wout", email = "david@vtwout.com" }] +requires-python = ">=3.11" # Because of entsoe-apy dependencies = [ "entsoe-apy", "fastapi", @@ -34,7 +76,7 @@ dependencies = [ ] [project.optional-dependencies] -dev = ["black", "pre-commit"] +dev = ["pre-commit", "pytest>=8.0", "pytest-cov", "mypy>=1.10", "ruff>=0.4", "types-PyYAML", "pandas-stubs"] [project.scripts] open-ess = "open_ess.main:main" diff --git a/shell.nix b/shell.nix index a2f8fe4..e8cb981 100644 --- a/shell.nix +++ b/shell.nix @@ -1,11 +1,24 @@ { pkgs ? import { } }: +# Note: ruff is dynamically linked and the version installed by pip won't work on NixOS. +# This can be fixed by adding `programs.nix-ld.enable = true;` to your NixOS config. + let open-ess = pkgs.python3.pkgs.callPackage ./default.nix { }; in pkgs.mkShell { packages = with pkgs; [ - (python3.withPackages (_: open-ess.propagatedBuildInputs)) - cbc # MILP solver for the optimizer + (python3.withPackages (pp: + open-ess.propagatedBuildInputs ++ (with pp; [ + # Dev tools; + pre-commit-hooks + mypy + ruff + # pytest and dependencies; + pytest + pytest-cov + ]))) + cbc # MILP solver for the optimizer esbuild + pre-commit ]; shellHook = '' diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py new file mode 100644 index 0000000..beb531e --- /dev/null +++ b/tests/test_optimizer.py @@ -0,0 +1,21 @@ +from collections.abc import Generator + +import pytest + +from open_ess.optimizer import Optimizer + + +@pytest.fixture +def optimizer() -> Generator[Optimizer, None, None]: + optimizer = Optimizer(None, None, None) # type: ignore[arg-type] + yield optimizer + + +class TestOptimizer: + def test(self) -> None: + """""" + pytest.skip("TODO: Implement test") + + def test_no_data(self) -> None: + """""" + pytest.skip("TODO: Implement test")