From 49577fab66978c0c78aad83b07789238c49e6326 Mon Sep 17 00:00:00 2001 From: Stefan Slivinski Date: Thu, 30 Apr 2026 09:56:10 -0700 Subject: [PATCH 1/3] feat(auth): replace cookie scraping with Auth0 OAuth + DPoP-bound refresh Replaces the manual cookie-extraction auth flow with the same Auth0 Universal Login + DPoP-bound refresh token flow the MyGenerac mobile app uses. Adds an email/password config flow, an in-place reauth step, a configurable scan_interval option, and bumps the API surface to v5 to match the app. - New auth.py: Auth0 Universal Login + DPoP token client - API endpoints bumped to v5 - Email/password config flow with reauth step - Options flow for scan_interval (default 900s) - Coordinator passes config_entry to silence HA 2025+ warning - entity: device_state_attributes -> extra_state_attributes - sensor: _safe_float helper for malformed numeric props - image: guard against missing content-type header - Tests updated end-to-end for new auth and api surfaces Existing entries hit reauth on next poll and prompt for MyGenerac email/password once. Tokens are then persisted in the config entry and refreshed automatically. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- custom_components/generac/__init__.py | 55 +- custom_components/generac/api.py | 149 +++-- custom_components/generac/auth.py | 554 ++++++++++++++++++ custom_components/generac/config_flow.py | 236 +++++--- custom_components/generac/const.py | 16 +- custom_components/generac/coordinator.py | 23 +- custom_components/generac/entity.py | 5 +- custom_components/generac/image.py | 3 +- custom_components/generac/manifest.json | 4 +- custom_components/generac/sensor.py | 33 +- .../generac/translations/en.json | 33 +- tests/test_api.py | 282 +++++---- tests/test_config_flow.py | 164 +++--- tests/test_entity.py | 3 +- tests/test_init.py | 76 ++- tests/test_sensor.py | 23 +- 16 files changed, 1212 insertions(+), 447 deletions(-) create mode 100644 custom_components/generac/auth.py diff --git a/custom_components/generac/__init__.py b/custom_components/generac/__init__.py index f506fe1..fdf76f9 100644 --- a/custom_components/generac/__init__.py +++ b/custom_components/generac/__init__.py @@ -2,17 +2,22 @@ Custom integration to integrate generac with Home Assistant. For more details about this integration, please refer to -https://github.com/binarydev/generac +https://github.com/binarydev/ha-generac """ + import logging from homeassistant.config_entries import ConfigEntry from homeassistant.core import HomeAssistant +from homeassistant.exceptions import ConfigEntryAuthFailed from homeassistant.exceptions import ConfigEntryNotReady from .api import GeneracApiClient -from .const import CONF_PASSWORD -from .const import CONF_SESSION_COOKIE +from .api import InvalidCredentialsException +from .auth import GeneracAuth +from .auth import InvalidGrantError +from .const import CONF_DPOP_PEM +from .const import CONF_REFRESH_TOKEN from .const import CONF_USERNAME from .const import DOMAIN from .const import PLATFORMS @@ -29,18 +34,43 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry): hass.data.setdefault(DOMAIN, {}) _LOGGER.info(STARTUP_MESSAGE) - username = entry.data.get(CONF_USERNAME, "") - password = entry.data.get(CONF_PASSWORD, "") - session_cookie = entry.data.get(CONF_SESSION_COOKIE, "") + refresh_token = entry.data.get(CONF_REFRESH_TOKEN) + pem_str = entry.data.get(CONF_DPOP_PEM) + email = entry.data.get(CONF_USERNAME) + + if not refresh_token or not pem_str: + # Either a fresh v1->v2 migration with stripped data, or + # somehow the credentials were lost. Either way, reauth. + raise ConfigEntryAuthFailed("Missing refresh token or DPoP key") session = await async_client_session(hass) - client = GeneracApiClient(session, username, password, session_cookie) + try: + auth = GeneracAuth.from_storage(session, refresh_token, pem_str, email=email) + except Exception as ex: + _LOGGER.error("Failed to load stored credentials: %s", ex) + raise ConfigEntryAuthFailed("Stored credentials are unreadable") from ex + async def _persist_rt(new_rt: str) -> None: + hass.config_entries.async_update_entry( + entry, data={**entry.data, CONF_REFRESH_TOKEN: new_rt} + ) + + auth.set_refresh_token_persist_callback(_persist_rt) + + client = GeneracApiClient(session, auth) coordinator = GeneracDataUpdateCoordinator(hass, client=client, config_entry=entry) try: await coordinator.async_config_entry_first_refresh() - except Exception as e: - raise ConfigEntryNotReady from e + except InvalidCredentialsException as ex: + raise ConfigEntryAuthFailed(str(ex)) from ex + except InvalidGrantError as ex: + raise ConfigEntryAuthFailed(str(ex)) from ex + except (ConfigEntryAuthFailed, ConfigEntryNotReady): + # Let HA handle these — the coordinator already raises the right + # one. Wrapping them in ConfigEntryNotReady would mask reauth. + raise + except Exception as ex: + raise ConfigEntryNotReady from ex if not coordinator.last_update_success: raise ConfigEntryNotReady @@ -48,7 +78,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry): hass.data[DOMAIN][entry.entry_id] = coordinator await hass.config_entries.async_forward_entry_setups(entry, PLATFORMS) - entry.add_update_listener(async_reload_entry) + entry.async_on_unload(entry.add_update_listener(async_reload_entry)) return True @@ -56,8 +86,9 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool: """Handle removal of an entry.""" unloaded = await hass.config_entries.async_unload_platforms(entry, PLATFORMS) if unloaded: - hass.data[DOMAIN].pop(entry.entry_id) - + # Defensive default: if a previous reload already popped the + # coordinator (e.g. mid-reconfigure race), don't KeyError. + hass.data[DOMAIN].pop(entry.entry_id, None) return unloaded diff --git a/custom_components/generac/api.py b/custom_components/generac/api.py index 5d9b61b..d03f1ed 100644 --- a/custom_components/generac/api.py +++ b/custom_components/generac/api.py @@ -1,78 +1,82 @@ -"""Generac API Client.""" +"""Generac MobileLink API client. + +The API itself is plain HTTPS + Bearer auth — no DPoP at this layer. +The Bearer token comes from `GeneracAuth`, which mints fresh access +tokens by exercising a DPoP-bound refresh_token against Auth0. + +API versioning: `/api/v1`, `/api/v2`, and `/api/v5` were all observed +returning identical payloads for the endpoints we use. The iOS app uses +`/api/v5`; we follow suit for futureproofing. +""" + import json import logging import aiohttp from dacite import from_dict -from .const import ALLOWED_DEVICES +from .auth import GeneracAuth, InvalidGrantError, USER_AGENT_API +from .const import ALLOWED_DEVICES, API_BASE from .models import Apparatus from .models import ApparatusDetail from .models import Item -API_BASE = "https://app.mobilelinkgen.com/api" -LOGIN_BASE = "https://generacconnectivity.b2clogin.com/generacconnectivity.onmicrosoft.com/B2C_1A_MobileLink_SignIn" - TIMEOUT = 10 - _LOGGER: logging.Logger = logging.getLogger(__package__) class InvalidCredentialsException(Exception): - pass + """Credentials supplied by the user were rejected.""" class SessionExpiredException(Exception): - pass + """The current access token / refresh token is no longer valid.""" class GeneracApiClient: + """HTTP client for the MobileLink API. + + The client owns the lifetime of the underlying auth handle's access + token but does NOT persist anything; persistence happens at the + ConfigEntry layer in `__init__.py`. + """ + def __init__( self, session: aiohttp.ClientSession, - username: str, - password: str, - session_cookie: str, + auth: GeneracAuth, ) -> None: - """Sample API Client.""" - self._username = username - self._password = password self._session = session - self._session_cookie = session_cookie - self._logged_in = False - self.csrf = "" - # Below is the login fix from https://github.com/bentekkie/ha-generac/pull/140 - self._headers = { - "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/122.0.0.0 Safari/537.36", - "Accept": "application/json, text/plain, */*", - "Accept-Language": "en-US,en;q=0.9", - "Accept-Encoding": "gzip, deflate, br", - "Connection": "keep-alive", - } + self._auth = auth async def async_get_data(self) -> dict[str, Item] | None: - """Get data from the API.""" - if self._session_cookie: - self._headers["Cookie"] = self._session_cookie - self._logged_in = True - else: - self._logged_in = False - _LOGGER.error("No session cookie provided, cannot login") - raise InvalidCredentialsException("No session cookie provided") + """Top-level entry point used by the coordinator.""" return await self.get_device_data() - async def get_device_data(self): - apparatuses = await self.get_endpoint("/v2/Apparatus/list") + async def get_device_data(self) -> dict[str, Item] | None: + apparatuses = await self.get_endpoint("/Apparatus/list") if apparatuses is None: - _LOGGER.debug("Could not decode apparatuses response") - return None + # Decode failure on /Apparatus/list — surface as a poll + # failure rather than treating it as "fleet has zero devices". + raise IOError("Failed to decode /Apparatus/list response") if not isinstance(apparatuses, list): - _LOGGER.error("Expected list from /v2/Apparatus/list got %s", apparatuses) + raise IOError( + f"Expected list from /Apparatus/list, got {type(apparatuses).__name__}: " + f"{str(apparatuses)[:200]}" + ) data: dict[str, Item] = {} - for apparatus in apparatuses: - apparatus = from_dict(Apparatus, apparatus) + for raw in apparatuses: + try: + apparatus = from_dict(Apparatus, raw) + except Exception as ex: + _LOGGER.warning( + "Skipping malformed apparatus entry: %s (raw=%s)", + ex, + str(raw)[:200], + ) + continue if apparatus.type not in ALLOWED_DEVICES: _LOGGER.debug( "Unknown apparatus type %s %s", apparatus.type, apparatus.name @@ -80,37 +84,62 @@ async def get_device_data(self): continue detail_json = await self.get_endpoint( - f"/v1/Apparatus/details/{apparatus.apparatusId}" + f"/Apparatus/details/{apparatus.apparatusId}" ) if detail_json is None: _LOGGER.debug( - f"Could not decode respose from /v1/Apparatus/details/{apparatus.apparatusId}" + "Could not decode response from /Apparatus/details/%s", + apparatus.apparatusId, + ) + continue + try: + detail = from_dict(ApparatusDetail, detail_json) + except Exception as ex: + _LOGGER.warning( + "Skipping apparatus %s due to malformed detail payload: %s", + apparatus.apparatusId, + ex, ) continue - detail = from_dict(ApparatusDetail, detail_json) data[str(apparatus.apparatusId)] = Item(apparatus, detail) return data async def get_endpoint(self, endpoint: str): try: - headers = {**self._headers} - if self.csrf: - headers["X-Csrf-Token"] = self.csrf - - response = await self._session.get(API_BASE + endpoint, headers=headers) - if response.status == 204: - # no data - return None - - if response.status != 200: - raise SessionExpiredException( - "API returned status code: %s " % response.status - ) + access_token = await self._auth.ensure_access_token() + except InvalidGrantError as ex: + raise InvalidCredentialsException(str(ex)) from ex + + headers = { + "Authorization": f"Bearer {access_token}", + "Accept": "application/json", + "User-Agent": USER_AGENT_API, + } - data = await response.json() - _LOGGER.debug("getEndpoint %s", json.dumps(data)) - return data + url = API_BASE + endpoint + try: + async with self._session.get(url, headers=headers) as response: + if response.status == 204: + return None + + if response.status == 401: + raise SessionExpiredException(f"API returned 401 for {endpoint}") + + if response.status != 200: + body = "" + try: + body = (await response.text())[:200] + except Exception: + pass + raise SessionExpiredException( + f"API returned status code {response.status} for " + f"{endpoint}: {body}" + ) + + data = await response.json() + _LOGGER.debug("getEndpoint %s", json.dumps(data)) + return data except SessionExpiredException: raise except Exception as ex: - raise IOError() from ex + raise IOError(f"GET {url} failed: {type(ex).__name__}: {ex}") from ex diff --git a/custom_components/generac/auth.py b/custom_components/generac/auth.py new file mode 100644 index 0000000..d6c0aa8 --- /dev/null +++ b/custom_components/generac/auth.py @@ -0,0 +1,554 @@ +"""Auth0 + DPoP authentication for the Generac Mobile Link API. + +This module owns the iOS-app-equivalent auth flow: + +* Email + password universal-login against `auth.ecobee.com` (Auth0 tenant + shared with the ecobee mobile apps). +* PKCE + DPoP-bound authorization code exchange. +* Refresh-token rotation off — the same RT is reusable indefinitely as + long as we keep proving possession of the original DPoP key. + +The DPoP private key is therefore part of the credential and must be +persisted alongside the refresh token. We expose the key as a PEM string +so it can live in the ConfigEntry's normal `data` dict. + +Refresh tokens for this client are NOT rotated by Auth0 (verified +empirically with multiple successive refreshes). We never need to +rewrite the ConfigEntry on a successful refresh. +""" + +from __future__ import annotations + +import asyncio +import base64 +import hashlib +import json +import logging +import re +import secrets +import time +import urllib.parse +import uuid +from dataclasses import dataclass +from typing import Awaitable, Callable, Optional + +import aiohttp +from cryptography.hazmat.primitives import hashes, serialization +from cryptography.hazmat.primitives.asymmetric import ec +from cryptography.hazmat.primitives.asymmetric.utils import decode_dss_signature + +_LOGGER = logging.getLogger(__name__) + +AUTH0_DOMAIN = "auth.ecobee.com" +AUTHORIZE_URL = f"https://{AUTH0_DOMAIN}/authorize" +TOKEN_URL = f"https://{AUTH0_DOMAIN}/oauth/token" +RESUME_URL = f"https://{AUTH0_DOMAIN}/authorize/resume" +IDENTIFIER_URL = f"https://{AUTH0_DOMAIN}/u/login/identifier" +PASSWORD_URL = f"https://{AUTH0_DOMAIN}/u/login/password" + +CLIENT_ID = "eyjSuHZLjX3JC1lNmougLa8rjUw666TN" +REDIRECT_URI = ( + "com.generac.mobilelink.auth0://auth.ecobee.com/ios/com.generac.mobilelink/callback" +) +SCOPE = "openid email offline_access invoke:api" +AUDIENCE = "https://prod.ecobee.com/api/v1" + +USER_AGENT_API = "mobilelink/86535 CFNetwork/3860.500.112 Darwin/25.4.0" +USER_AGENT_WEB = ( + "Mozilla/5.0 (iPhone; CPU iPhone OS 26_4 like Mac OS X) " + "AppleWebKit/605.1.15 (KHTML, like Gecko) Mobile/15E148" +) + +# Mirrors the Auth0.swift 2.16.2 SDK header captured from the iOS app. +_AUTH0_CLIENT_HEADER = ( + base64.urlsafe_b64encode( + json.dumps( + { + "env": {"swift": "6.x", "iOS": "26.4"}, + "version": "2.16.2", + "name": "Auth0.swift", + }, + separators=(",", ":"), + ).encode() + ) + .rstrip(b"=") + .decode() +) + + +def _b64url(data: bytes) -> str: + return base64.urlsafe_b64encode(data).rstrip(b"=").decode("ascii") + + +def _int_to_b64url(n: int, length: int = 32) -> str: + return _b64url(n.to_bytes(length, "big")) + + +class InvalidGrantError(Exception): + """Raised when the refresh token has been invalidated server-side. + + The caller should map this to `ConfigEntryAuthFailed` so HA prompts + the user to re-authenticate. + """ + + +class InvalidCredentialsError(Exception): + """Raised when the user-supplied email/password is rejected at login.""" + + +@dataclass +class DPoPKey: + """An ES256 keypair plus precomputed JWK + RFC 7638 thumbprint.""" + + private_key: ec.EllipticCurvePrivateKey + jwk: dict + thumbprint: str + + @classmethod + def generate(cls) -> "DPoPKey": + priv = ec.generate_private_key(ec.SECP256R1()) + return cls._from_private(priv) + + @classmethod + def from_pem(cls, pem: bytes) -> "DPoPKey": + priv = serialization.load_pem_private_key(pem, password=None) + if not isinstance(priv, ec.EllipticCurvePrivateKey): + raise ValueError("expected EC private key") + return cls._from_private(priv) + + @classmethod + def from_pem_str(cls, pem: str) -> "DPoPKey": + return cls.from_pem(pem.encode("ascii")) + + @classmethod + def _from_private(cls, priv: ec.EllipticCurvePrivateKey) -> "DPoPKey": + nums = priv.public_key().public_numbers() + jwk = { + "crv": "P-256", + "kty": "EC", + "x": _int_to_b64url(nums.x), + "y": _int_to_b64url(nums.y), + } + canonical = json.dumps(jwk, separators=(",", ":"), sort_keys=True).encode() + thumbprint = _b64url(hashlib.sha256(canonical).digest()) + return cls(private_key=priv, jwk=jwk, thumbprint=thumbprint) + + def to_pem(self) -> bytes: + return self.private_key.private_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PrivateFormat.PKCS8, + encryption_algorithm=serialization.NoEncryption(), + ) + + def to_pem_str(self) -> str: + return self.to_pem().decode("ascii") + + def sign_proof( + self, + htm: str, + htu: str, + nonce: Optional[str] = None, + access_token: Optional[str] = None, + ) -> str: + header = {"alg": "ES256", "typ": "dpop+jwt", "jwk": self.jwk} + payload: dict = { + "jti": str(uuid.uuid4()), + "htm": htm.upper(), + "htu": htu, + "iat": int(time.time()), + } + if nonce is not None: + payload["nonce"] = nonce + if access_token is not None: + ath = hashlib.sha256(access_token.encode("ascii")).digest() + payload["ath"] = _b64url(ath) + + signing_input = ( + _b64url(json.dumps(header, separators=(",", ":")).encode()) + + "." + + _b64url(json.dumps(payload, separators=(",", ":")).encode()) + ).encode("ascii") + + der_sig = self.private_key.sign(signing_input, ec.ECDSA(hashes.SHA256())) + r, s = decode_dss_signature(der_sig) + raw_sig = r.to_bytes(32, "big") + s.to_bytes(32, "big") + return signing_input.decode("ascii") + "." + _b64url(raw_sig) + + +def _make_pkce() -> tuple[str, str]: + verifier = _b64url(secrets.token_bytes(32)) + challenge = _b64url(hashlib.sha256(verifier.encode("ascii")).digest()) + return verifier, challenge + + +# --------------------------------------------------------------------------- +# Login flow (one-shot, runs from the config flow when user submits creds) +# --------------------------------------------------------------------------- + + +async def _authorize( + session: aiohttp.ClientSession, key: DPoPKey, state: str, challenge: str +) -> str: + params = { + "response_type": "code", + "code_challenge": challenge, + "code_challenge_method": "S256", + "redirect_uri": REDIRECT_URI, + "scope": SCOPE, + "audience": AUDIENCE, + "state": state, + "dpop_jkt": key.thumbprint, + "client_id": CLIENT_ID, + "prompt": "login", + "login_hint": "", + "auth0Client": _AUTH0_CLIENT_HEADER, + } + headers = {"User-Agent": USER_AGENT_WEB, "Accept": "text/html,*/*"} + async with session.get( + AUTHORIZE_URL, params=params, headers=headers, allow_redirects=False + ) as resp: + if resp.status not in (302, 303): + raise RuntimeError(f"/authorize: expected redirect, got {resp.status}") + loc = resp.headers["Location"] + set_cookies = resp.headers.getall("Set-Cookie", []) + cookie_names = sorted(c.key for c in session.cookie_jar) + _LOGGER.debug( + "/authorize -> 302 %s set-cookie-count=%d jar-after=%s", + loc[:120], + len(set_cookies), + cookie_names, + ) + qs = urllib.parse.parse_qs(urllib.parse.urlparse(loc).query) + if "state" not in qs: + raise RuntimeError(f"/authorize: no state in redirect: {loc}") + return qs["state"][0] + + +async def _post_login_form( + session: aiohttp.ClientSession, url: str, state: str, form: dict +) -> str: + headers = { + "User-Agent": USER_AGENT_WEB, + "Content-Type": "application/x-www-form-urlencoded", + "Accept": "text/html,*/*", + "Origin": f"https://{AUTH0_DOMAIN}", + "Referer": f"{url}?state={state}", + } + body = urllib.parse.urlencode(form) + async with session.post( + url, + params={"state": state}, + data=body, + headers=headers, + allow_redirects=False, + ) as resp: + if resp.status not in (302, 303): + text = await resp.text() + # Auth0 ULP renders field-level errors as + # class="ulp-input-error-message" data-error-code="" + # Surface the first code so the user sees a meaningful reason + # instead of a bare HTTP 400. + m = re.search(r'data-error-code="([^"]+)"', text) + code = m.group(1) if m else None + _LOGGER.warning( + "POST %s -> %s; auth0 error code=%s", url, resp.status, code + ) + if code: + # Auth0 ULP renders field-level errors (wrong password, + # locked account, etc) with a data-error-code. Surface + # those as InvalidCredentialsError so the config flow + # maps them to "auth" instead of "internal". + if any( + s in code.lower() + for s in ("password", "credential", "user", "lock", "blocked") + ): + raise InvalidCredentialsError(f"login rejected ({code})") + raise RuntimeError(f"POST {url} -> {resp.status}: {code}") + raise RuntimeError(f"POST {url} -> {resp.status}") + return resp.headers["Location"] + + +async def _identifier_step( + session: aiohttp.ClientSession, state: str, email: str +) -> str: + form = { + "state": state, + "username": email, + "js-available": "true", + "webauthn-available": "true", + "is-brave": "false", + "webauthn-platform-available": "true", + "action": "default", + } + loc = await _post_login_form(session, IDENTIFIER_URL, state, form) + parsed = urllib.parse.urlparse(loc) + if not parsed.path.endswith("/u/login/password"): + # Auth0 sends us back to /u/login/identifier when the email is + # not recognized; surface that as bad credentials. + raise InvalidCredentialsError("email not recognized") + return urllib.parse.parse_qs(parsed.query)["state"][0] + + +async def _password_step( + session: aiohttp.ClientSession, state: str, email: str, password: str +) -> str: + form = { + "state": state, + "username": email, + "password": password, + "action": "default", + } + loc = await _post_login_form(session, PASSWORD_URL, state, form) + parsed = urllib.parse.urlparse(loc) + if not parsed.path.endswith("/authorize/resume"): + raise InvalidCredentialsError(f"password rejected: {loc}") + return urllib.parse.parse_qs(parsed.query)["state"][0] + + +async def _resume_to_code(session: aiohttp.ClientSession, resume_state: str) -> str: + headers = {"User-Agent": USER_AGENT_WEB, "Accept": "text/html,*/*"} + async with session.get( + RESUME_URL, + params={"state": resume_state}, + headers=headers, + allow_redirects=False, + ) as resp: + if resp.status not in (302, 303): + raise RuntimeError(f"/authorize/resume -> {resp.status}") + loc = resp.headers["Location"] + if not loc.startswith("com.generac.mobilelink.auth0://"): + raise RuntimeError(f"/authorize/resume: unexpected scheme: {loc}") + qs = urllib.parse.parse_qs(urllib.parse.urlparse(loc).query) + if "code" not in qs: + raise RuntimeError(f"/authorize/resume: no code in redirect: {loc}") + return qs["code"][0] + + +async def _exchange_code( + session: aiohttp.ClientSession, key: DPoPKey, code: str, verifier: str +) -> dict: + body = { + "grant_type": "authorization_code", + "client_id": CLIENT_ID, + "code": code, + "code_verifier": verifier, + "redirect_uri": REDIRECT_URI, + } + + async def _post(nonce: str | None) -> tuple[int, dict, str | None]: + proof = key.sign_proof("POST", TOKEN_URL, nonce=nonce) + headers = { + "Content-Type": "application/json", + "Accept": "application/json", + "DPoP": proof, + "User-Agent": USER_AGENT_API, + } + async with session.post(TOKEN_URL, json=body, headers=headers) as resp: + text = await resp.text() + try: + payload = json.loads(text) + except json.JSONDecodeError: + payload = {"raw": text} + return resp.status, payload, resp.headers.get("dpop-nonce") + + status, payload, nonce = await _post(None) + if status == 200: + return payload + if status == 400 and payload.get("error") == "use_dpop_nonce" and nonce: + status, payload, _ = await _post(nonce) + if status == 200: + return payload + raise RuntimeError(f"code exchange failed: {status} {payload}") + + +# --------------------------------------------------------------------------- +# GeneracAuth — the main reusable handle +# --------------------------------------------------------------------------- + + +class GeneracAuth: + """Holds the long-lived credentials (RT + DPoP key) and mints fresh ATs.""" + + # Refresh slightly before expiry so callers always see a fresh token. + _ACCESS_TOKEN_LEEWAY = 60 + + def __init__( + self, + session: aiohttp.ClientSession, + refresh_token: str, + key: DPoPKey, + *, + email: Optional[str] = None, + ) -> None: + self._session = session + self._refresh_token = refresh_token + self._key = key + self._email = email + self._access_token: Optional[str] = None + self._access_token_exp: float = 0.0 + self._dpop_nonce: Optional[str] = None + self._refresh_lock = asyncio.Lock() + self._rt_persist_cb: Optional[Callable[[str], Awaitable[None]]] = None + + def set_refresh_token_persist_callback( + self, cb: Optional[Callable[[str], Awaitable[None]]] + ) -> None: + """Register an async callback invoked when Auth0 rotates the RT. + + The callback receives the new refresh token and is responsible for + persisting it (typically into the ConfigEntry's `data` dict). + """ + self._rt_persist_cb = cb + + @classmethod + async def login( + cls, session: aiohttp.ClientSession, email: str, password: str + ) -> "GeneracAuth": + """Run the full Auth0 universal-login flow and return a ready instance. + + The Auth0 universal-login flow is stateful: /authorize sets a session + cookie that /u/login/identifier and /u/login/password require. Some + shared sessions disable cookie quoting or scrub cookies between calls, + which breaks the handshake. Use a dedicated cookie-jar-backed session + for the login flow only; the long-lived `session` is reused afterward + for refresh-token rotation, which doesn't depend on cookies. + """ + key = DPoPKey.generate() + verifier, challenge = _make_pkce() + state = _b64url(secrets.token_bytes(32)) + + jar = aiohttp.CookieJar(unsafe=True) + async with aiohttp.ClientSession(cookie_jar=jar) as login_session: + login_state = await _authorize(login_session, key, state, challenge) + pw_state = await _identifier_step(login_session, login_state, email) + resume_state = await _password_step( + login_session, pw_state, email, password + ) + code = await _resume_to_code(login_session, resume_state) + tokens = await _exchange_code(login_session, key, code, verifier) + + refresh_token = tokens.get("refresh_token") + if not refresh_token: + raise RuntimeError("login: no refresh_token returned") + + auth = cls(session, refresh_token, key, email=email) + auth._access_token = tokens["access_token"] + auth._access_token_exp = time.time() + int(tokens.get("expires_in", 0)) + _LOGGER.info( + "Login OK: expires_in=%s scope=%s token_type=%s", + tokens.get("expires_in"), + tokens.get("scope"), + tokens.get("token_type"), + ) + return auth + + @classmethod + def from_storage( + cls, + session: aiohttp.ClientSession, + refresh_token: str, + pem_str: str, + *, + email: Optional[str] = None, + ) -> "GeneracAuth": + key = DPoPKey.from_pem_str(pem_str) + return cls(session, refresh_token, key, email=email) + + @property + def refresh_token(self) -> str: + return self._refresh_token + + @property + def pem_str(self) -> str: + return self._key.to_pem_str() + + @property + def email(self) -> Optional[str]: + return self._email + + async def ensure_access_token(self) -> str: + """Return a non-expired access token, refreshing if necessary.""" + if ( + self._access_token + and time.time() < self._access_token_exp - self._ACCESS_TOKEN_LEEWAY + ): + return self._access_token + + async with self._refresh_lock: + # Double-check inside the lock — concurrent callers may have + # already refreshed by the time we acquired it. + if ( + self._access_token + and time.time() < self._access_token_exp - self._ACCESS_TOKEN_LEEWAY + ): + return self._access_token + await self._refresh() + assert self._access_token is not None + return self._access_token + + async def _refresh(self) -> None: + body = { + "grant_type": "refresh_token", + "client_id": CLIENT_ID, + "refresh_token": self._refresh_token, + } + + async def _post(nonce: str | None) -> tuple[int, dict, str | None]: + proof = self._key.sign_proof("POST", TOKEN_URL, nonce=nonce) + headers = { + "Content-Type": "application/json", + "Accept": "application/json", + "DPoP": proof, + "User-Agent": USER_AGENT_API, + } + async with self._session.post( + TOKEN_URL, json=body, headers=headers + ) as resp: + text = await resp.text() + try: + payload = json.loads(text) + except json.JSONDecodeError: + payload = {"raw": text} + return resp.status, payload, resp.headers.get("dpop-nonce") + + status, payload, nonce = await _post(self._dpop_nonce) + if status == 400 and payload.get("error") == "use_dpop_nonce" and nonce: + self._dpop_nonce = nonce + status, payload, nonce2 = await _post(nonce) + if nonce2: + self._dpop_nonce = nonce2 + + if status == 200: + self._access_token = payload["access_token"] + self._access_token_exp = time.time() + int(payload.get("expires_in", 0)) + _LOGGER.info( + "Token refresh OK: expires_in=%s scope=%s token_type=%s", + payload.get("expires_in"), + payload.get("scope"), + payload.get("token_type"), + ) + # Auth0 rotation is OFF for this client, but be defensive: if + # the server ever does rotate, capture the new RT. + new_rt = payload.get("refresh_token") + if new_rt and new_rt != self._refresh_token: + self._refresh_token = new_rt + if self._rt_persist_cb is not None: + try: + await self._rt_persist_cb(new_rt) + _LOGGER.info("Refresh token rotated and persisted") + except Exception: # noqa: BLE001 + _LOGGER.exception( + "Refresh token rotated but persist callback failed; " + "next HA restart may need reauth" + ) + else: + _LOGGER.warning( + "Refresh token rotated but no persist callback registered; " + "next HA restart will need reauth" + ) + return + + if status == 400 and payload.get("error") == "invalid_grant": + raise InvalidGrantError(payload.get("error_description", "invalid_grant")) + + raise RuntimeError(f"token refresh failed: {status} {payload}") diff --git a/custom_components/generac/config_flow.py b/custom_components/generac/config_flow.py index 78a3c49..51f3b24 100644 --- a/custom_components/generac/config_flow.py +++ b/custom_components/generac/config_flow.py @@ -1,19 +1,31 @@ -"""Adds config flow for generac.""" -import json +"""Config flow for the Generac MobileLink integration. + +Auth model (v2): + user submits email + password + -> we run the full Auth0/DPoP login flow inside the flow + -> we persist (email, refresh_token, dpop_pem) in entry.data + -> entry.unique_id = email + +Reauth: when the refresh token gets invalidated server-side, the +coordinator raises ConfigEntryAuthFailed and HA invokes +async_step_reauth here. We collect a fresh password (email is locked to +the entry's unique_id) and overwrite the credentials in place. +""" import logging -import re -import urllib.parse import voluptuous as vol from homeassistant import config_entries from homeassistant.core import callback -from .api import GeneracApiClient -from .api import InvalidCredentialsException +from .auth import GeneracAuth +from .auth import InvalidCredentialsError +from .const import CONF_DPOP_PEM from .const import CONF_OPTIONS from .const import CONF_PASSWORD -from .const import CONF_SESSION_COOKIE +from .const import CONF_REFRESH_TOKEN +from .const import CONF_SCAN_INTERVAL from .const import CONF_USERNAME +from .const import DEFAULT_SCAN_INTERVAL from .const import DOMAIN from .utils import async_client_session @@ -26,134 +38,174 @@ class GeneracFlowHandler(config_entries.ConfigFlow, domain=DOMAIN): VERSION = 1 def __init__(self): - """Initialize.""" - self._errors = {} - - def _extract_email_from_cookie(self, cookie_str): - # Find the MobileLinkClientCookie value using regex - match = re.search(r"MobileLinkClientCookie=([^;]+)", cookie_str) - if not match: - return None - encoded_json = match.group(1) - # URL decode the JSON string - decoded_json = urllib.parse.unquote(encoded_json) - # Parse the JSON to a dict + self._reauth_entry: config_entries.ConfigEntry | None = None + + async def _try_login( + self, email: str, password: str + ) -> tuple[dict | None, str | None]: + """Run the full login flow. Returns (entry_data, error_key).""" try: - data = json.loads(decoded_json) - return data.get("signInName", "") - except Exception: - return None + session = await async_client_session(self.hass) + auth = await GeneracAuth.login(session, email, password) + except InvalidCredentialsError as ex: + _LOGGER.warning("Login rejected by Auth0: %s", ex) + return None, "auth" + except Exception as ex: + _LOGGER.error("Unexpected error during login: %s", ex, exc_info=True) + return None, "internal" + + return ( + { + CONF_USERNAME: email, + CONF_REFRESH_TOKEN: auth.refresh_token, + CONF_DPOP_PEM: auth.pem_str, + }, + None, + ) + + async def async_step_user(self, user_input=None): + """Handle a flow initialized by the user.""" + errors: dict[str, str] = {} + + if user_input is not None: + email = user_input[CONF_USERNAME] + password = user_input[CONF_PASSWORD] + + entry_data, error = await self._try_login(email, password) + if error is None: + await self.async_set_unique_id(email) + self._abort_if_unique_id_configured() + return self.async_create_entry(title=email, data=entry_data) + + errors["base"] = error + + return self.async_show_form( + step_id="user", + data_schema=vol.Schema( + { + vol.Required(CONF_USERNAME): str, + vol.Required(CONF_PASSWORD): str, + } + ), + errors=errors, + ) async def async_step_reconfigure(self, user_input=None): - """Handle reconfiguration.""" - errors = {} + """Handle reconfiguration of an existing entry.""" + errors: dict[str, str] = {} entry = self.hass.config_entries.async_get_entry(self.context["entry_id"]) if user_input is not None: - session_cookie = user_input.get(CONF_SESSION_COOKIE, "") - error = await self._test_credentials( - "", - "", - session_cookie, - ) + email = user_input[CONF_USERNAME] + password = user_input[CONF_PASSWORD] + scan_interval = user_input.get(CONF_SCAN_INTERVAL, DEFAULT_SCAN_INTERVAL) + + entry_data, error = await self._try_login(email, password) if error is None: - return self.async_update_reload_and_abort( + # Persist polling interval into entry.options so the + # coordinator picks it up the same way the OptionsFlow + # does. + new_options = { + **(entry.options or {}), + CONF_SCAN_INTERVAL: int(scan_interval), + } + # Update only — the update listener registered in + # async_setup_entry will reload the entry exactly once. + # Calling async_update_reload_and_abort here would + # double-reload (helper schedules + listener fires) and + # race the unload, surfacing as "failed to unload". + self.hass.config_entries.async_update_entry( entry, - data={**entry.data, **user_input}, - reason="Reconfigure Successful", + data={**entry.data, **entry_data}, + options=new_options, ) + return self.async_abort(reason="Reconfigure Successful") errors["base"] = error + default_email = entry.data.get(CONF_USERNAME, "") if entry else "" + default_scan_interval = ( + (entry.options or {}).get(CONF_SCAN_INTERVAL, DEFAULT_SCAN_INTERVAL) + if entry + else DEFAULT_SCAN_INTERVAL + ) return self.async_show_form( step_id="reconfigure", data_schema=vol.Schema( { + vol.Required(CONF_USERNAME, default=default_email): str, + vol.Required(CONF_PASSWORD): str, vol.Required( - CONF_SESSION_COOKIE, - default=entry.data.get(CONF_SESSION_COOKIE), - ): str, + CONF_SCAN_INTERVAL, default=default_scan_interval + ): int, } ), errors=errors, ) - async def async_step_user(self, user_input=None): - """Handle a flow initialized by the user.""" - self._errors = {} - - # Uncomment the next 2 lines if only a single instance of the integration is allowed: - # if self._async_current_entries(): - # return self.async_abort(reason="single_instance_allowed") + async def async_step_reauth(self, entry_data): + """Trigger a reauth flow when the stored RT has been invalidated.""" + self._reauth_entry = self.hass.config_entries.async_get_entry( + self.context["entry_id"] + ) + return await self.async_step_reauth_confirm() + + async def async_step_reauth_confirm(self, user_input=None): + """Collect fresh credentials to mint new tokens.""" + errors: dict[str, str] = {} + entry = self._reauth_entry + assert entry is not None + # Older config entries may not have stored email under CONF_USERNAME, + # so fall back to the entry title (which we set to the email at + # create time). + default_email = entry.data.get(CONF_USERNAME) or entry.title or "" if user_input is not None: - username = user_input.get(CONF_USERNAME, "") - session_cookie = user_input.get(CONF_SESSION_COOKIE, "") - error = await self._test_credentials( - username, - user_input.get(CONF_PASSWORD, ""), - session_cookie, - ) - if error is None and session_cookie: - unique_id = self._extract_email_from_cookie(session_cookie) or "generac" - - await self.async_set_unique_id(unique_id) - self._abort_if_unique_id_configured() - - return self.async_create_entry(title=unique_id, data=user_input) - else: - self._errors["base"] = error - - return await self._show_config_form(user_input) - - return await self._show_config_form(user_input) - - @staticmethod - @callback - def async_get_options_flow(config_entry): - return GeneracOptionsFlowHandler(config_entry) + password = user_input[CONF_PASSWORD] + # Reauth is bound to the entry's existing email; users who + # need a different account must remove and re-add the + # integration. This prevents silently rebinding the entry + # (and all its entities) to a different Generac account. + email = default_email + + entry_data, error = await self._try_login(email, password) + if error is None: + # Update only — the update listener registered in + # async_setup_entry handles the reload. An explicit + # async_reload here would race the listener-driven + # reload and surface as "failed to unload". + self.hass.config_entries.async_update_entry( + entry, data={**entry.data, **entry_data} + ) + return self.async_abort(reason="reauth_successful") + errors["base"] = error - async def _show_config_form(self, user_input): # pylint: disable=unused-argument - """Show the configuration form to edit location data.""" return self.async_show_form( - step_id="user", + step_id="reauth_confirm", data_schema=vol.Schema( { - # vol.Optional(CONF_USERNAME): str, - # vol.Optional(CONF_PASSWORD): str, - vol.Required(CONF_SESSION_COOKIE): str, + vol.Required(CONF_PASSWORD): str, } ), - errors=self._errors, + description_placeholders={"username": default_email}, + errors=errors, ) - async def _test_credentials(self, username, password, session_cookie): - """Return true if credentials is valid.""" - try: - session = await async_client_session(self.hass) - client = GeneracApiClient(session, username, password, session_cookie) - await client.async_get_data() - return None - except InvalidCredentialsException as e: - _LOGGER.error("Invalid credentials: %s", e) - return "auth" - except Exception as e: - _LOGGER.error("Unexpected error testing credentials: %s", e, exc_info=True) - return "internal" + @staticmethod + @callback + def async_get_options_flow(config_entry): + return GeneracOptionsFlowHandler(config_entry) class GeneracOptionsFlowHandler(config_entries.OptionsFlow): """Config flow options handler for generac.""" def __init__(self, config_entry): - """Initialize HACS options flow.""" self.options = dict(config_entry.options) async def async_step_init(self, user_input=None): # pylint: disable=unused-argument - """Manage the options.""" return await self.async_step_user() async def async_step_user(self, user_input=None): - """Handle a flow initialized by the user.""" if user_input is not None: self.options.update(user_input) return self.async_create_entry(title="", data=self.options) diff --git a/custom_components/generac/const.py b/custom_components/generac/const.py index 0b93990..28b7429 100644 --- a/custom_components/generac/const.py +++ b/custom_components/generac/const.py @@ -1,4 +1,5 @@ """Constants for generac.""" + # Base component constants NAME = "generac" DOMAIN = "generac" @@ -26,7 +27,11 @@ # Defaults DEFAULT_NAME = DOMAIN -DEFAULT_SCAN_INTERVAL = 120 +# 900 s = 15 min. Generac's cloud doesn't push very often (~minutes +# between updates) and the API is rate-limited per-account, so polling +# faster than this provides little benefit and risks hitting their +# throttles. +DEFAULT_SCAN_INTERVAL = 900 # Platforms BINARY_SENSOR = "binary_sensor" @@ -39,7 +44,11 @@ CONF_ENABLED = "enabled" CONF_USERNAME = "username" CONF_PASSWORD = "password" -CONF_SESSION_COOKIE = "session_cookie" +# Credentials: refresh token + DPoP private key (stored as PKCS8 PEM). +# These two values together are the credential — losing either invalidates +# the entry and forces reauth. +CONF_REFRESH_TOKEN = "refresh_token" +CONF_DPOP_PEM = "dpop_pem" CONF_SCAN_INTERVAL = "scan_interval" # Options @@ -62,5 +71,4 @@ """ -API_BASE = "https://app.mobilelinkgen.com/api" -LOGIN_BASE = "https://generacconnectivity.b2clogin.com/generacconnectivity.onmicrosoft.com/B2C_1A_MobileLink_SignIn" +API_BASE = "https://app.mobilelinkgen.com/api/v5" diff --git a/custom_components/generac/coordinator.py b/custom_components/generac/coordinator.py index 8ba2158..a0272fe 100644 --- a/custom_components/generac/coordinator.py +++ b/custom_components/generac/coordinator.py @@ -3,16 +3,19 @@ from homeassistant.config_entries import ConfigEntry from homeassistant.core import HomeAssistant +from homeassistant.exceptions import ConfigEntryAuthFailed from homeassistant.helpers.update_coordinator import DataUpdateCoordinator from homeassistant.helpers.update_coordinator import UpdateFailed from .api import GeneracApiClient +from .api import InvalidCredentialsException +from .api import SessionExpiredException +from .auth import InvalidGrantError from .const import CONF_SCAN_INTERVAL from .const import DEFAULT_SCAN_INTERVAL from .const import DOMAIN from .models import Item - _LOGGER: logging.Logger = logging.getLogger(__package__) @@ -35,9 +38,23 @@ def __init__( async def _async_update_data(self): """Update data via library.""" try: - _LOGGER.debug("Refreshing data for generac") + _LOGGER.info("Polling Generac cloud for device data") items = await self.api.async_get_data() self.is_online = items is not None + _LOGGER.info("Generac poll OK: %d device(s)", len(items) if items else 0) return items + except (InvalidCredentialsException, InvalidGrantError) as ex: + # Refresh token / login no longer valid — trigger HA reauth flow. + _LOGGER.warning("Generac auth rejected, requesting reauth: %s", ex) + self.is_online = False + raise ConfigEntryAuthFailed(str(ex)) from ex + except SessionExpiredException as ex: + # 401 / non-200 from the API. Surface it loudly so it shows up + # in HA logs instead of silently freezing entities. + _LOGGER.warning("Generac API session error: %s", ex) + self.is_online = False + raise UpdateFailed(f"API session error: {ex}") from ex except Exception as exception: - raise UpdateFailed() from exception + _LOGGER.exception("Unexpected error refreshing Generac data") + self.is_online = False + raise UpdateFailed(str(exception)) from exception diff --git a/custom_components/generac/entity.py b/custom_components/generac/entity.py index 1072231..0bd7bdf 100644 --- a/custom_components/generac/entity.py +++ b/custom_components/generac/entity.py @@ -47,7 +47,7 @@ def device_info(self): ) @property - def device_state_attributes(self): + def extra_state_attributes(self): """Return the state attributes.""" return { "attribution": ATTRIBUTION, @@ -63,9 +63,6 @@ def available(self): async def async_added_to_hass(self) -> None: """Connect to dispatcher listening for entity data notifications.""" await super().async_added_to_hass() - self.async_on_remove( - self.coordinator.async_add_listener(self.async_write_ha_state) - ) @property def aparatus(self) -> Apparatus: diff --git a/custom_components/generac/image.py b/custom_components/generac/image.py index 87dc0a7..475706e 100644 --- a/custom_components/generac/image.py +++ b/custom_components/generac/image.py @@ -1,4 +1,5 @@ """Image platform for generac.""" + import mimetypes import httpx @@ -62,7 +63,7 @@ async def _fetch_url(self, url: str) -> httpx.Response | None: resp = await super()._fetch_url(url) if ( resp is not None - and "image" not in resp.headers.get("content-type") + and "image" not in (resp.headers.get("content-type") or "") and self.aparatus_detail.heroImageUrl ): guess = mimetypes.guess_type(self.aparatus_detail.heroImageUrl)[0] diff --git a/custom_components/generac/manifest.json b/custom_components/generac/manifest.json index c57fb28..bd9a4b9 100644 --- a/custom_components/generac/manifest.json +++ b/custom_components/generac/manifest.json @@ -7,6 +7,6 @@ "documentation": "https://github.com/binarydev/ha-generac", "iot_class": "cloud_polling", "issue_tracker": "https://github.com/binarydev/ha-generac/issues", - "requirements": ["dacite==1.8.1"], - "version": "0.4.2" + "requirements": ["dacite==1.8.1", "cryptography>=41"], + "version": "0.5.0" } diff --git a/custom_components/generac/sensor.py b/custom_components/generac/sensor.py index a88dd34..422628a 100644 --- a/custom_components/generac/sensor.py +++ b/custom_components/generac/sensor.py @@ -1,4 +1,6 @@ """Sensor platform for generac.""" + +import logging from datetime import datetime from typing import Type @@ -109,6 +111,25 @@ def sensor_name(self, name_label): return f"{DEFAULT_NAME}_{self.device_id}_{name_label}" +_LOGGER = logging.getLogger(__name__) + + +def _safe_float(val, label: str = ""): + """Best-effort float conversion; return None on bad data so the + sensor reports ``unknown`` instead of crashing native_value.""" + if val is None: + return None + if isinstance(val, (int, float)): + return float(val) + try: + return float(val) + except (TypeError, ValueError) as ex: + _LOGGER.debug( + "Could not convert %s sensor value %r to float: %s", label, val, ex + ) + return None + + class StatusSensor(GeneracEntity, SensorEntity): """generac Sensor class.""" @@ -198,9 +219,7 @@ def name(self): def native_value(self): """Return the state of the sensor.""" val = get_prop_value(self.aparatus_detail.properties, 71, 0) - if isinstance(val, str): - val = float(val) - return val + return _safe_float(val, "run_time") class ProtectionTimeSensor(GeneracEntity, SensorEntity): @@ -218,9 +237,7 @@ def name(self): def native_value(self): """Return the state of the sensor.""" val = get_prop_value(self.aparatus_detail.properties, 32, 0) - if isinstance(val, str): - val = float(val) - return val + return _safe_float(val, "protection_time") class ActivationDateSensor(GeneracEntity, SensorEntity): @@ -295,9 +312,7 @@ def name(self): def native_value(self): """Return the state of the sensor.""" val = get_prop_value(self.aparatus_detail.properties, 70, 0) - if isinstance(val, str): - val = float(val) - return val + return _safe_float(val, "battery_voltage") class OutdoorTemperatureSensor(GeneracEntity, SensorEntity): diff --git a/custom_components/generac/translations/en.json b/custom_components/generac/translations/en.json index 8d17c52..d27a54d 100644 --- a/custom_components/generac/translations/en.json +++ b/custom_components/generac/translations/en.json @@ -2,21 +2,38 @@ "config": { "step": { "user": { - "title": "generac", - "description": "If you need help with the configuration have a look here: github.com/binarydev/ha-generac", + "title": "Generac MobileLink", + "description": "Sign in with your Mobile Link account.", "data": { - "username": "Username", + "username": "Email", + "password": "Password" + } + }, + "reconfigure": { + "title": "Generac MobileLink", + "description": "Re-enter your Mobile Link credentials.", + "data": { + "username": "Email", "password": "Password", - "session_cookie": "Session Cookie" + "scan_interval": "Cloud polling interval (seconds)" + } + }, + "reauth_confirm": { + "title": "Reauthenticate Generac MobileLink", + "description": "The session for {username} has expired. Enter your password to refresh it.", + "data": { + "password": "Password" } } }, "error": { - "auth": "Username/Password is wrong.", - "internal": "Internal Error, file an issue at github.com/binarydev/ha-generac." + "auth": "Invalid email or password.", + "internal": "Unexpected error. File an issue at github.com/binarydev/ha-generac." }, "abort": { - "single_instance_allowed": "Only a single instance is allowed." + "single_instance_allowed": "Only a single instance is allowed.", + "already_configured": "This account is already configured.", + "reauth_successful": "Re-authentication successful." } }, "options": { @@ -25,7 +42,7 @@ "data": { "binary_sensor": "Binary sensors enabled", "image": "Image enabled", - "scan_interval": "Polling interval (seconds)", + "scan_interval": "Cloud polling interval (seconds)", "sensor": "Non-binary sensors enabled", "switch": "Switch enabled", "weather": "Weather sensor enabled" diff --git a/tests/test_api.py b/tests/test_api.py index 6052c5e..1adfb01 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -1,216 +1,200 @@ """Tests for Generac API Client.""" + import asyncio import json from unittest.mock import AsyncMock from unittest.mock import MagicMock -from unittest.mock import patch import pytest from custom_components.generac.api import GeneracApiClient from custom_components.generac.api import InvalidCredentialsException from custom_components.generac.api import SessionExpiredException +from custom_components.generac.auth import InvalidGrantError from custom_components.generac.const import ALLOWED_DEVICES +from custom_components.generac.const import API_BASE from custom_components.generac.const import DEVICE_TYPE_UNKNOWN +def _acm(response): + """Wrap a mock response so it works with `async with session.get(...)`. + + aiohttp's `ClientSession.get()` is a regular function that returns an + object supporting the async context manager protocol, NOT a coroutine. + Using `AsyncMock()` for `session.get` would make `.get(...)` return a + coroutine, which breaks `async with`. Instead, `session.get` is a plain + `MagicMock` whose return value is wrapped here as an async-cm yielding + `response` from `__aenter__`. + """ + cm = MagicMock() + cm.__aenter__ = AsyncMock(return_value=response) + cm.__aexit__ = AsyncMock(return_value=None) + return cm + + @pytest.fixture def mock_session(): """Fixture for aiohttp ClientSession.""" session = MagicMock() - session.get = AsyncMock() + session.get = MagicMock() return session @pytest.fixture -def client(mock_session): - """Fixture for GeneracApiClient.""" - return GeneracApiClient(mock_session, "test_user", "test_pass", "test_cookie") - - -async def test_init(mock_session): - """Test the __init__ method.""" - client = GeneracApiClient(mock_session, "test_user", "test_pass", "test_cookie") - assert client._username == "test_user" - assert client._password == "test_pass" - assert client._session == mock_session - assert client._session_cookie == "test_cookie" - assert client._logged_in is False - assert client.csrf == "" - - -async def test_async_get_data_with_cookie(client, mock_session): - """Test async_get_data with a session cookie.""" - with patch.object( - client, "get_device_data", new_callable=AsyncMock - ) as mock_get_device_data: - mock_get_device_data.return_value = {"device": "data"} - result = await client.async_get_data() - assert client._headers["Cookie"] == "test_cookie" - assert client._logged_in is True - mock_get_device_data.assert_called_once() - assert result == {"device": "data"} - - -async def test_async_get_data_no_cookie(client, mock_session): - """Test async_get_data with no session cookie.""" - client._session_cookie = "" - with pytest.raises(InvalidCredentialsException): - await client.async_get_data() - - -async def test_get_device_data_success(client, mock_session): - """Test get_device_data success.""" - apparatus_list = [ - {"apparatusId": 1, "name": "Generator 1", "type": ALLOWED_DEVICES[0]}, - {"apparatusId": 2, "name": "Generator 2", "type": DEVICE_TYPE_UNKNOWN}, - ] - apparatus_detail = {"name": "Generator 1", "status": "Ready"} - - mock_session.get.side_effect = [ - AsyncMock(status=200, json=AsyncMock(return_value=apparatus_list)), - AsyncMock(status=200, json=AsyncMock(return_value=apparatus_detail)), - ] +def mock_auth(): + """Fixture for a GeneracAuth-like double that hands out a static AT.""" + auth = MagicMock() + auth.ensure_access_token = AsyncMock(return_value="fake-at") + return auth - result = await client.get_device_data() - assert "1" in result - assert result["1"].apparatus.name == "Generator 1" - assert result["1"].apparatusDetail.name == "Generator 1" - assert "2" not in result - - -async def test_get_device_data_no_apparatuses(client, mock_session): - """Test get_device_data with no apparatuses.""" - mock_session.get.return_value = AsyncMock( - status=200, json=AsyncMock(return_value=[]) - ) - result = await client.get_device_data() - assert result == {} - - -async def test_get_device_data_apparatus_none(client, mock_session): - """Test get_device_data with None apparatuses.""" - mock_session.get.return_value = AsyncMock( - status=200, json=AsyncMock(return_value=None) - ) - result = await client.get_device_data() - assert result is None +@pytest.fixture +def client(mock_session, mock_auth): + """Fixture for GeneracApiClient.""" + return GeneracApiClient(mock_session, mock_auth) -async def test_get_device_data_apparatus_not_a_list(client, mock_session): - """Test get_device_data with non-list apparatuses.""" - mock_session.get.return_value = AsyncMock( - status=200, json=AsyncMock(return_value={"key": "value"}) - ) - result = await client.get_device_data() - assert result == {} +async def test_init(mock_session, mock_auth): + """Test the __init__ method binds the session + auth handle.""" + client = GeneracApiClient(mock_session, mock_auth) + assert client._session is mock_session + assert client._auth is mock_auth -async def test_get_device_data_no_detail(client, mock_session): - """Test get_device_data with no detail.""" - apparatus_list = [ - {"apparatusId": 1, "name": "Generator 1", "type": ALLOWED_DEVICES[0]} - ] - mock_session.get.side_effect = [ - AsyncMock(status=200, json=AsyncMock(return_value=apparatus_list)), - AsyncMock(status=204), - ] - result = await client.get_device_data() - assert result == {} - +async def test_get_endpoint_uses_bearer_and_v5_base(client, mock_session, mock_auth): + """The Bearer token from auth.ensure_access_token() is on every request.""" + response = AsyncMock(status=200) + response.headers = {"Content-Type": "application/json"} + response.json = AsyncMock(return_value={"key": "value"}) + mock_session.get.return_value = _acm(response) -async def test_get_endpoint_success(client, mock_session): - """Test get_endpoint success.""" - mock_session.get.return_value = AsyncMock( - status=200, json=AsyncMock(return_value={"key": "value"}) - ) result = await client.get_endpoint("/test") assert result == {"key": "value"} - -async def test_get_endpoint_with_csrf(client, mock_session): - """Test get_endpoint with csrf token.""" - client.csrf = "test_csrf_token" - client._headers["Cookie"] = "test_cookie" - mock_session.get.return_value = AsyncMock( - status=200, json=AsyncMock(return_value={"key": "value"}) - ) - result = await client.get_endpoint("/test") - assert result == {"key": "value"} - mock_session.get.assert_called_with( - "https://app.mobilelinkgen.com/api/test", - headers={ - "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/122.0.0.0 Safari/537.36", - "Accept": "application/json, text/plain, */*", - "Accept-Language": "en-US,en;q=0.9", - "Accept-Encoding": "gzip, deflate, br", - "Connection": "keep-alive", - "Cookie": "test_cookie", - "X-Csrf-Token": "test_csrf_token", - }, - ) + mock_auth.ensure_access_token.assert_awaited_once() + args, kwargs = mock_session.get.call_args + assert args[0] == f"{API_BASE}/test" + assert kwargs["headers"]["Authorization"] == "Bearer fake-at" + assert kwargs["headers"]["Accept"] == "application/json" async def test_get_endpoint_no_content(client, mock_session): - """Test get_endpoint with 204 No Content.""" - mock_session.get.return_value = AsyncMock(status=204) + """204 No Content returns None.""" + mock_session.get.return_value = _acm(AsyncMock(status=204)) result = await client.get_endpoint("/test") assert result is None -async def test_get_endpoint_session_expired(client, mock_session): - """Test get_endpoint with session expired.""" - mock_session.get.return_value = AsyncMock(status=401) +async def test_get_endpoint_session_expired_401(client, mock_session): + """401 from API raises SessionExpiredException.""" + mock_session.get.return_value = _acm(AsyncMock(status=401)) with pytest.raises(SessionExpiredException): await client.get_endpoint("/test") async def test_get_endpoint_server_error(client, mock_session): - """Test get_endpoint with a server error.""" - mock_session.get.return_value = AsyncMock(status=500) + """Non-2xx non-204 non-401 still raises SessionExpiredException.""" + mock_session.get.return_value = _acm(AsyncMock(status=500)) with pytest.raises(SessionExpiredException): await client.get_endpoint("/test") -async def test_get_endpoint_io_error(client, mock_session): - """Test get_endpoint with IOError.""" +async def test_get_endpoint_io_error_on_timeout(client, mock_session): + """Network timeout becomes IOError.""" mock_session.get.side_effect = asyncio.TimeoutError with pytest.raises(IOError): await client.get_endpoint("/test") -async def test_get_endpoint_invalid_content_type(client, mock_session): - """Test get_endpoint with invalid content type.""" - response = AsyncMock(status=200) - response.headers = {"Content-Type": "text/html"} - mock_session.get.return_value = response - with pytest.raises(IOError): - await client.get_endpoint("/test") - - async def test_get_endpoint_json_decode_error(client, mock_session): - """Test get_endpoint with a JSON decode error.""" + """Bad JSON becomes IOError.""" response = AsyncMock(status=200) response.headers = {"Content-Type": "application/json"} response.json = AsyncMock(side_effect=json.JSONDecodeError("msg", "doc", 0)) - mock_session.get.return_value = response + mock_session.get.return_value = _acm(response) with pytest.raises(IOError): await client.get_endpoint("/test") async def test_get_endpoint_generic_exception(client, mock_session): - """Test get_endpoint with a generic exception.""" + """Generic transport failure becomes IOError.""" mock_session.get.side_effect = Exception("A generic error occurred") with pytest.raises(IOError): await client.get_endpoint("/test") -async def test_async_get_data_session_expired(client, mock_session): - """Test async_get_data with SessionExpiredException.""" - client._session_cookie = "test_cookie" # Start with a cookie - with patch.object( - client, "get_device_data", new_callable=AsyncMock - ) as mock_get_device_data: - mock_get_device_data.side_effect = SessionExpiredException - with pytest.raises(SessionExpiredException): - await client.async_get_data() +async def test_get_endpoint_invalid_grant_maps_to_invalid_credentials( + client, mock_auth +): + """Auth0 InvalidGrant during AT mint surfaces as InvalidCredentialsException.""" + mock_auth.ensure_access_token = AsyncMock( + side_effect=InvalidGrantError("rt revoked") + ) + with pytest.raises(InvalidCredentialsException): + await client.get_endpoint("/test") + + +async def test_get_device_data_success(client, mock_session): + """Happy path: list -> details, only allowed device types kept.""" + apparatus_list = [ + {"apparatusId": 1, "name": "Generator 1", "type": ALLOWED_DEVICES[0]}, + {"apparatusId": 2, "name": "Generator 2", "type": DEVICE_TYPE_UNKNOWN}, + ] + apparatus_detail = {"name": "Generator 1", "status": "Ready"} + + list_resp = AsyncMock(status=200) + list_resp.headers = {"Content-Type": "application/json"} + list_resp.json = AsyncMock(return_value=apparatus_list) + + detail_resp = AsyncMock(status=200) + detail_resp.headers = {"Content-Type": "application/json"} + detail_resp.json = AsyncMock(return_value=apparatus_detail) + + mock_session.get.side_effect = [_acm(list_resp), _acm(detail_resp)] + + result = await client.get_device_data() + + assert "1" in result + assert result["1"].apparatus.name == "Generator 1" + assert result["1"].apparatusDetail.name == "Generator 1" + assert "2" not in result + + +async def test_get_device_data_no_apparatuses(client, mock_session): + """Empty list -> empty dict.""" + resp = AsyncMock(status=200) + resp.headers = {"Content-Type": "application/json"} + resp.json = AsyncMock(return_value=[]) + mock_session.get.return_value = _acm(resp) + result = await client.get_device_data() + assert result == {} + + +async def test_get_device_data_apparatus_none(client, mock_session): + """Decode failure on /Apparatus/list surfaces as IOError (poll failure).""" + mock_session.get.return_value = _acm(AsyncMock(status=204)) + with pytest.raises(IOError, match="Failed to decode /Apparatus/list response"): + await client.get_device_data() + + +async def test_get_device_data_apparatus_not_a_list(client, mock_session): + """Unexpected dict instead of list surfaces as IOError (poll failure).""" + resp = AsyncMock(status=200) + resp.headers = {"Content-Type": "application/json"} + resp.json = AsyncMock(return_value={"key": "value"}) + mock_session.get.return_value = _acm(resp) + with pytest.raises(IOError, match="Expected list from /Apparatus/list"): + await client.get_device_data() + + +async def test_get_device_data_no_detail(client, mock_session): + """Apparatus list returns one device but details endpoint is 204.""" + apparatus_list = [ + {"apparatusId": 1, "name": "Generator 1", "type": ALLOWED_DEVICES[0]} + ] + list_resp = AsyncMock(status=200) + list_resp.headers = {"Content-Type": "application/json"} + list_resp.json = AsyncMock(return_value=apparatus_list) + + mock_session.get.side_effect = [_acm(list_resp), _acm(AsyncMock(status=204))] + result = await client.get_device_data() + assert result == {} diff --git a/tests/test_config_flow.py b/tests/test_config_flow.py index 4dbd8fd..9ef3e13 100644 --- a/tests/test_config_flow.py +++ b/tests/test_config_flow.py @@ -1,8 +1,16 @@ """Test the Generac config flow.""" + +from unittest.mock import AsyncMock +from unittest.mock import MagicMock from unittest.mock import patch import pytest -from custom_components.generac.api import InvalidCredentialsException +from custom_components.generac.auth import DPoPKey +from custom_components.generac.auth import InvalidCredentialsError +from custom_components.generac.const import CONF_DPOP_PEM +from custom_components.generac.const import CONF_PASSWORD +from custom_components.generac.const import CONF_REFRESH_TOKEN +from custom_components.generac.const import CONF_USERNAME from custom_components.generac.const import DOMAIN from homeassistant import config_entries from homeassistant import setup @@ -10,8 +18,17 @@ from pytest_homeassistant_custom_component.common import MockConfigEntry -async def test_form(hass: HomeAssistant) -> None: - """Test we get the form.""" +def _mock_auth(refresh_token: str = "rt-abc", email: str = "user@example.com"): + """Build a fake GeneracAuth-like object that login() returns.""" + auth = MagicMock() + auth.refresh_token = refresh_token + auth.pem_str = DPoPKey.generate().to_pem_str() + auth.email = email + return auth + + +async def test_form_user(hass: HomeAssistant) -> None: + """User submits valid email+password and the entry is created.""" await setup.async_setup_component(hass, "persistent_notification", {}) result = await hass.config_entries.flow.async_init( DOMAIN, context={"source": config_entries.SOURCE_USER} @@ -19,9 +36,10 @@ async def test_form(hass: HomeAssistant) -> None: assert result["type"] == "form" assert result["errors"] == {} + fake_auth = _mock_auth() with patch( - "custom_components.generac.config_flow.GeneracApiClient.async_get_data", - return_value=True, + "custom_components.generac.config_flow.GeneracAuth.login", + AsyncMock(return_value=fake_auth), ), patch( "custom_components.generac.async_setup_entry", return_value=True, @@ -29,35 +47,35 @@ async def test_form(hass: HomeAssistant) -> None: result2 = await hass.config_entries.flow.async_configure( result["flow_id"], { - "session_cookie": "MobileLinkClientCookie=%7B%0D%0A%20%20%22signInName%22%3A%20%22binarydev%40testing.com%22%0D%0A%7D", + CONF_USERNAME: "user@example.com", + CONF_PASSWORD: "hunter2", }, ) await hass.async_block_till_done() assert result2["type"] == "create_entry" - assert result2["title"] == "binarydev@testing.com" - assert result2["data"] == { - "session_cookie": "MobileLinkClientCookie=%7B%0D%0A%20%20%22signInName%22%3A%20%22binarydev%40testing.com%22%0D%0A%7D", - } + assert result2["title"] == "user@example.com" + assert result2["data"][CONF_USERNAME] == "user@example.com" + assert result2["data"][CONF_REFRESH_TOKEN] == "rt-abc" + assert CONF_DPOP_PEM in result2["data"] + assert CONF_PASSWORD not in result2["data"] assert len(mock_setup_entry.mock_calls) == 1 async def test_form_invalid_auth(hass: HomeAssistant) -> None: - """Test we handle invalid auth.""" + """Invalid credentials surface as a form error.""" await setup.async_setup_component(hass, "persistent_notification", {}) result = await hass.config_entries.flow.async_init( DOMAIN, context={"source": config_entries.SOURCE_USER} ) with patch( - "custom_components.generac.config_flow.GeneracApiClient.async_get_data", - side_effect=InvalidCredentialsException, + "custom_components.generac.config_flow.GeneracAuth.login", + AsyncMock(side_effect=InvalidCredentialsError("bad creds")), ): result2 = await hass.config_entries.flow.async_configure( result["flow_id"], - { - "session_cookie": "bad-cookie", - }, + {CONF_USERNAME: "user@example.com", CONF_PASSWORD: "wrong"}, ) assert result2["type"] == "form" @@ -65,67 +83,50 @@ async def test_form_invalid_auth(hass: HomeAssistant) -> None: async def test_form_internal_error(hass: HomeAssistant) -> None: - """Test we handle an internal error.""" + """Unexpected exception surfaces as internal error.""" await setup.async_setup_component(hass, "persistent_notification", {}) result = await hass.config_entries.flow.async_init( DOMAIN, context={"source": config_entries.SOURCE_USER} ) with patch( - "custom_components.generac.config_flow.GeneracApiClient.async_get_data", - side_effect=Exception, + "custom_components.generac.config_flow.GeneracAuth.login", + AsyncMock(side_effect=RuntimeError("boom")), ): result2 = await hass.config_entries.flow.async_configure( result["flow_id"], - { - "session_cookie": "bad-cookie", - }, + {CONF_USERNAME: "user@example.com", CONF_PASSWORD: "any"}, ) assert result2["type"] == "form" assert result2["errors"] == {"base": "internal"} -async def test_form_malformed_cookie(hass: HomeAssistant) -> None: - """Test we handle a malformed cookie.""" - await setup.async_setup_component(hass, "persistent_notification", {}) - result = await hass.config_entries.flow.async_init( - DOMAIN, context={"source": config_entries.SOURCE_USER} +async def test_duplicate_entry(hass: HomeAssistant) -> None: + """Same email twice should abort as already_configured.""" + existing = MockConfigEntry( + domain=DOMAIN, + unique_id="user@example.com", + data={CONF_USERNAME: "user@example.com"}, ) + existing.add_to_hass(hass) - with patch( - "custom_components.generac.config_flow.GeneracApiClient.async_get_data", - return_value=True, - ): - result2 = await hass.config_entries.flow.async_configure( - result["flow_id"], - { - "session_cookie": "MobileLinkClientCookie=not-json", - }, - ) - - assert result2["type"] == "create_entry" - - -async def test_form_no_cookie(hass: HomeAssistant) -> None: - """Test we handle a cookie with no signin name.""" await setup.async_setup_component(hass, "persistent_notification", {}) result = await hass.config_entries.flow.async_init( DOMAIN, context={"source": config_entries.SOURCE_USER} ) with patch( - "custom_components.generac.config_flow.GeneracApiClient.async_get_data", - return_value=True, + "custom_components.generac.config_flow.GeneracAuth.login", + AsyncMock(return_value=_mock_auth()), ): result2 = await hass.config_entries.flow.async_configure( result["flow_id"], - { - "session_cookie": "foo=bar", - }, + {CONF_USERNAME: "user@example.com", CONF_PASSWORD: "any"}, ) - assert result2["type"] == "create_entry" + assert result2["type"] == "abort" + assert result2["reason"] == "already_configured" @pytest.mark.asyncio @@ -148,7 +149,6 @@ async def test_options_flow(hass: HomeAssistant) -> None: await hass.async_block_till_done() result = await hass.config_entries.options.async_init(entry.entry_id) - assert result["type"] == "form" assert result["step_id"] == "user" @@ -168,9 +168,16 @@ async def test_options_flow(hass: HomeAssistant) -> None: @pytest.mark.asyncio async def test_reconfigure_flow(hass: HomeAssistant) -> None: - """Test the reconfigure flow.""" + """Reconfigure should re-run login and update entry data.""" + pem = DPoPKey.generate().to_pem_str() entry = MockConfigEntry( - domain=DOMAIN, data={"session_cookie": "old_cookie"}, options={} + domain=DOMAIN, + data={ + CONF_USERNAME: "user@example.com", + CONF_REFRESH_TOKEN: "old-rt", + CONF_DPOP_PEM: pem, + }, + options={}, ) entry.add_to_hass(hass) @@ -180,58 +187,63 @@ async def test_reconfigure_flow(hass: HomeAssistant) -> None: result = await hass.config_entries.flow.async_init( DOMAIN, - context={ - "source": "reconfigure", - "entry_id": entry.entry_id, - }, + context={"source": "reconfigure", "entry_id": entry.entry_id}, ) assert result["type"] == "form" assert result["step_id"] == "reconfigure" + new_auth = _mock_auth(refresh_token="new-rt") with patch( - "custom_components.generac.config_flow.GeneracApiClient.async_get_data", - return_value=True, + "custom_components.generac.config_flow.GeneracAuth.login", + AsyncMock(return_value=new_auth), ), patch("custom_components.generac.async_setup_entry", return_value=True), patch( "custom_components.generac.async_unload_entry", return_value=True ): result2 = await hass.config_entries.flow.async_configure( result["flow_id"], - { - "session_cookie": "new_cookie", - }, + {CONF_USERNAME: "user@example.com", CONF_PASSWORD: "new-pw"}, ) await hass.async_block_till_done() assert result2["type"] == "abort" - assert result2["reason"] == "Reconfigure Successful" - assert entry.data["session_cookie"] == "new_cookie" + assert entry.data[CONF_REFRESH_TOKEN] == "new-rt" -async def test_duplicate_entry(hass: HomeAssistant) -> None: - """Test duplicate entry is handled.""" +async def test_reauth_flow(hass: HomeAssistant) -> None: + """Reauth should re-prompt password (email locked) and update credentials.""" + pem = DPoPKey.generate().to_pem_str() entry = MockConfigEntry( domain=DOMAIN, - unique_id="binarydev@testing.com", - data={"session_cookie": "existing"}, + unique_id="user@example.com", + data={ + CONF_USERNAME: "user@example.com", + CONF_REFRESH_TOKEN: "stale-rt", + CONF_DPOP_PEM: pem, + }, ) entry.add_to_hass(hass) - await setup.async_setup_component(hass, "persistent_notification", {}) result = await hass.config_entries.flow.async_init( - DOMAIN, context={"source": config_entries.SOURCE_USER} + DOMAIN, + context={"source": "reauth", "entry_id": entry.entry_id}, + data=entry.data, ) + assert result["type"] == "form" + assert result["step_id"] == "reauth_confirm" + + new_auth = _mock_auth(refresh_token="fresh-rt") with patch( - "custom_components.generac.config_flow.GeneracApiClient.async_get_data", - return_value=True, - ): + "custom_components.generac.config_flow.GeneracAuth.login", + AsyncMock(return_value=new_auth), + ), patch("custom_components.generac.async_setup_entry", return_value=True): result2 = await hass.config_entries.flow.async_configure( result["flow_id"], - { - "session_cookie": "MobileLinkClientCookie=%7B%0D%0A%20%20%22signInName%22%3A%20%22binarydev%40testing.com%22%0D%0A%7D", - }, + {CONF_PASSWORD: "new-pw"}, ) + await hass.async_block_till_done() assert result2["type"] == "abort" - assert result2["reason"] == "already_configured" + assert result2["reason"] == "reauth_successful" + assert entry.data[CONF_REFRESH_TOKEN] == "fresh-rt" diff --git a/tests/test_entity.py b/tests/test_entity.py index 59a9e88..f8fc850 100644 --- a/tests/test_entity.py +++ b/tests/test_entity.py @@ -1,4 +1,5 @@ """Test the Generac entity.""" + from unittest.mock import MagicMock from custom_components.generac.entity import GeneracEntity @@ -34,7 +35,7 @@ async def test_entity(hass): "model": "G12345", "manufacturer": "Generac", } - assert entity.device_state_attributes == { + assert entity.extra_state_attributes == { "attribution": "Data provided by https://app.mobilelinkgen.com/api. This is reversed engineered. Heavily inspired by https://github.com/digitaldan/openhab-addons/blob/generac-2.0/bundles/org.openhab.binding.generacmobilelink/README.md", "id": "12345", "integration": "generac", diff --git a/tests/test_init.py b/tests/test_init.py index 247ef70..b24d6a8 100644 --- a/tests/test_init.py +++ b/tests/test_init.py @@ -1,57 +1,99 @@ """Test generac setup process.""" + from unittest.mock import patch import pytest from custom_components.generac import async_reload_entry from custom_components.generac import async_setup_entry from custom_components.generac import async_unload_entry +from custom_components.generac.auth import DPoPKey +from custom_components.generac.const import CONF_DPOP_PEM +from custom_components.generac.const import CONF_REFRESH_TOKEN +from custom_components.generac.const import CONF_USERNAME from custom_components.generac.const import DOMAIN from homeassistant.core import HomeAssistant +from homeassistant.exceptions import ConfigEntryAuthFailed from homeassistant.exceptions import ConfigEntryNotReady from pytest_homeassistant_custom_component.common import MockConfigEntry -MOCK_CONFIG = {"session_cookie": "test_cookie"} + +def _make_mock_config(): + """Build a valid v1 entry with a real DPoP PEM.""" + return { + CONF_USERNAME: "user@example.com", + CONF_REFRESH_TOKEN: "fake-refresh-token", + CONF_DPOP_PEM: DPoPKey.generate().to_pem_str(), + } async def test_setup_unload_and_reload_entry(hass: HomeAssistant, bypass_get_data): """Test entry setup and unload.""" - # Create a mock entry so we don't have to go through config flow - config_entry = MockConfigEntry(domain=DOMAIN, data=MOCK_CONFIG, entry_id="test") + config_entry = MockConfigEntry( + domain=DOMAIN, data=_make_mock_config(), entry_id="test" + ) config_entry.add_to_hass(hass) - # Set up the entry and assert that the values set during setup are where we expect - # them to be. Because we have a mock coordinator, none of the values is actually - # filled in. await hass.config_entries.async_setup(config_entry.entry_id) assert await async_setup_entry(hass, config_entry) assert DOMAIN in hass.data and config_entry.entry_id in hass.data[DOMAIN] - # Reload the entry and assert that the data from above is still there assert await async_reload_entry(hass, config_entry) is None assert DOMAIN in hass.data and config_entry.entry_id in hass.data[DOMAIN] - # Unload the entry and verify that the data has been removed assert await async_unload_entry(hass, config_entry) assert config_entry.entry_id not in hass.data[DOMAIN] async def test_setup_entry_exception(hass: HomeAssistant, error_on_get_data): - """Test config entry not ready.""" - # Create a mock entry so we don't have to go through config flow - config_entry = MockConfigEntry(domain=DOMAIN, data=MOCK_CONFIG, entry_id="test") + """Test config entry not ready when API errors.""" + config_entry = MockConfigEntry( + domain=DOMAIN, data=_make_mock_config(), entry_id="test" + ) config_entry.add_to_hass(hass) - # In this case we are testing the condition where async_setup_entry raises - # ConfigEntryNotReady using the `error_on_get_data` fixture which simulates - # an error fetching the data. with pytest.raises(ConfigEntryNotReady): await async_setup_entry(hass, config_entry) +async def test_setup_entry_missing_refresh_token(hass: HomeAssistant): + """Missing refresh_token in entry data should trigger reauth.""" + bad = _make_mock_config() + del bad[CONF_REFRESH_TOKEN] + config_entry = MockConfigEntry(domain=DOMAIN, data=bad, entry_id="test") + config_entry.add_to_hass(hass) + + with pytest.raises(ConfigEntryAuthFailed): + await async_setup_entry(hass, config_entry) + + +async def test_setup_entry_missing_pem(hass: HomeAssistant): + """Missing DPoP PEM in entry data should trigger reauth.""" + bad = _make_mock_config() + del bad[CONF_DPOP_PEM] + config_entry = MockConfigEntry(domain=DOMAIN, data=bad, entry_id="test") + config_entry.add_to_hass(hass) + + with pytest.raises(ConfigEntryAuthFailed): + await async_setup_entry(hass, config_entry) + + +async def test_setup_entry_corrupt_pem(hass: HomeAssistant): + """Corrupt PEM string should trigger reauth, not crash.""" + bad = _make_mock_config() + bad[CONF_DPOP_PEM] = "not-a-pem" + config_entry = MockConfigEntry(domain=DOMAIN, data=bad, entry_id="test") + config_entry.add_to_hass(hass) + + with pytest.raises(ConfigEntryAuthFailed): + await async_setup_entry(hass, config_entry) + + async def test_setup_entry_existing_domain(hass: HomeAssistant, bypass_get_data): """Test entry setup with existing domain data.""" hass.data[DOMAIN] = {"existing_entry": "data"} - config_entry = MockConfigEntry(domain=DOMAIN, data=MOCK_CONFIG, entry_id="test") + config_entry = MockConfigEntry( + domain=DOMAIN, data=_make_mock_config(), entry_id="test" + ) config_entry.add_to_hass(hass) await hass.config_entries.async_setup(config_entry.entry_id) @@ -62,7 +104,9 @@ async def test_setup_entry_existing_domain(hass: HomeAssistant, bypass_get_data) async def test_unload_entry_failed(hass: HomeAssistant, bypass_get_data): """Test entry unload failed.""" - config_entry = MockConfigEntry(domain=DOMAIN, data=MOCK_CONFIG, entry_id="test") + config_entry = MockConfigEntry( + domain=DOMAIN, data=_make_mock_config(), entry_id="test" + ) config_entry.add_to_hass(hass) await hass.config_entries.async_setup(config_entry.entry_id) diff --git a/tests/test_sensor.py b/tests/test_sensor.py index 59be2bf..f5e5c76 100644 --- a/tests/test_sensor.py +++ b/tests/test_sensor.py @@ -1,4 +1,5 @@ """Test the Generac sensor platform.""" + from unittest.mock import MagicMock from custom_components.generac.const import DEVICE_TYPE_GENERATOR @@ -102,16 +103,18 @@ def get_mock_item( MagicMock(type=11, value=last_reading_date), MagicMock(type=17, value=battery_level), ], - weather=Weather( - temperature=Weather.Temperature( - value=outdoor_temperature, - unit=outdoor_temperature_unit, - unitType=outdoor_temperature_unit_type, - ), - iconCode=weather_icon_code, - ) - if outdoor_temperature is not None - else None, + weather=( + Weather( + temperature=Weather.Temperature( + value=outdoor_temperature, + unit=outdoor_temperature_unit, + unitType=outdoor_temperature_unit_type, + ), + iconCode=weather_icon_code, + ) + if outdoor_temperature is not None + else None + ), ), ) From b683df366043e5403107442990e80fe0f43323e2 Mon Sep 17 00:00:00 2001 From: Stefan Slivinski Date: Wed, 29 Apr 2026 20:53:41 -0700 Subject: [PATCH 2/3] test: add P0/P1 coverage for auth, coordinator, init, config flow, entity, sensors - auth.py: refresh-token success/forbidden/invalid_grant/invalid_request paths plus 7 form-parser cases for the legacy CSRF login flow. - coordinator.py: default + override scan_interval; InvalidCredentialsException and InvalidGrantError both surface as ConfigEntryAuthFailed. - __init__.py: persist callback wired into auth on setup; first-refresh InvalidCredentialsException / InvalidGrantError raise ConfigEntryAuthFailed. - config_flow.py: reauth flow keeps original email (not user-supplied) when calling login(); reconfigure persists CONF_SCAN_INTERVAL into entry.options. - entity.py: missing device id falls back to _EMPTY_ITEM (available=False). - sensor.py: _safe_float treats None / 'N/A' / non-numeric as None instead of crashing native_value (RunTime, ProtectionTime, BatteryVoltage). Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- tests/test_auth.py | 213 ++++++++++++++++++++++++++++++++++++++ tests/test_config_flow.py | 92 ++++++++++++++++ tests/test_coordinator.py | 46 ++++++++ tests/test_entity.py | 23 ++++ tests/test_init.py | 62 +++++++++++ tests/test_sensor.py | 16 +++ 6 files changed, 452 insertions(+) create mode 100644 tests/test_auth.py diff --git a/tests/test_auth.py b/tests/test_auth.py new file mode 100644 index 0000000..e72336c --- /dev/null +++ b/tests/test_auth.py @@ -0,0 +1,213 @@ +"""Tests for GeneracAuth refresh + Auth0 ULP form-parser error handling.""" +from __future__ import annotations + +import json +from unittest.mock import AsyncMock +from unittest.mock import MagicMock + +import pytest +from custom_components.generac.auth import _post_login_form +from custom_components.generac.auth import DPoPKey +from custom_components.generac.auth import GeneracAuth +from custom_components.generac.auth import InvalidCredentialsError +from custom_components.generac.auth import InvalidGrantError + + +def _acm(resp): + """Wrap a response object as an async context manager.""" + cm = MagicMock() + cm.__aenter__ = AsyncMock(return_value=resp) + cm.__aexit__ = AsyncMock(return_value=None) + return cm + + +def _token_resp(status: int, body: dict, dpop_nonce: str | None = None): + resp = MagicMock() + resp.status = status + resp.text = AsyncMock(return_value=json.dumps(body)) + resp.headers = MagicMock() + resp.headers.get = MagicMock(return_value=dpop_nonce) + return resp + + +def _make_auth(refresh_token: str = "rt-OLD") -> GeneracAuth: + session = MagicMock() + key = DPoPKey.generate() + return GeneracAuth(session, refresh_token, key, email="user@example.com") + + +# --------------------------------------------------------------------------- +# P0: refresh-token persist callback + invalid_grant handling +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_refresh_persist_callback_fires_on_rotation(): + """When Auth0 rotates the refresh token, the persist callback must run.""" + auth = _make_auth(refresh_token="rt-OLD") + auth._session.post = MagicMock( + return_value=_acm( + _token_resp( + 200, + { + "access_token": "AT-1", + "expires_in": 3600, + "refresh_token": "rt-NEW", + "scope": "openid", + "token_type": "Bearer", + }, + ) + ) + ) + + persist_cb = AsyncMock() + auth.set_refresh_token_persist_callback(persist_cb) + + await auth._refresh() + + assert auth._access_token == "AT-1" + assert auth._refresh_token == "rt-NEW" + persist_cb.assert_awaited_once_with("rt-NEW") + + +@pytest.mark.asyncio +async def test_refresh_persist_callback_not_called_when_rt_unchanged(): + """If the server returns the same RT (no rotation), do NOT call the persist cb.""" + auth = _make_auth(refresh_token="rt-OLD") + auth._session.post = MagicMock( + return_value=_acm( + _token_resp( + 200, + { + "access_token": "AT-1", + "expires_in": 3600, + "refresh_token": "rt-OLD", + }, + ) + ) + ) + + persist_cb = AsyncMock() + auth.set_refresh_token_persist_callback(persist_cb) + + await auth._refresh() + + assert auth._refresh_token == "rt-OLD" + persist_cb.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_refresh_persist_callback_not_called_when_omitted(): + """If the server omits refresh_token entirely, do NOT call the persist cb.""" + auth = _make_auth(refresh_token="rt-OLD") + auth._session.post = MagicMock( + return_value=_acm( + _token_resp( + 200, + {"access_token": "AT-1", "expires_in": 3600}, + ) + ) + ) + + persist_cb = AsyncMock() + auth.set_refresh_token_persist_callback(persist_cb) + + await auth._refresh() + + assert auth._refresh_token == "rt-OLD" + persist_cb.assert_not_awaited() + + +@pytest.mark.asyncio +async def test_refresh_invalid_grant_raises(): + """A 400 invalid_grant from the token endpoint must raise InvalidGrantError.""" + auth = _make_auth(refresh_token="rt-REVOKED") + auth._session.post = MagicMock( + return_value=_acm( + _token_resp( + 400, + {"error": "invalid_grant", "error_description": "rt revoked"}, + ) + ) + ) + + with pytest.raises(InvalidGrantError): + await auth._refresh() + + +# --------------------------------------------------------------------------- +# P1: Auth0 ULP form-parser error code mapping +# --------------------------------------------------------------------------- + + +def _ulp_error_resp(status: int, code: str | None): + """Build a fake Auth0 ULP HTML response advertising a data-error-code.""" + resp = MagicMock() + resp.status = status + if code is not None: + html = ( + '
err
' + ) + else: + html = "some other failure" + resp.text = AsyncMock(return_value=html) + resp.headers = {"Location": "/dummy"} + return resp + + +@pytest.mark.parametrize( + "code", + [ + "invalid-password", + "invalid_credentials", + "user-not-found", + "account-locked", + "user-blocked", + ], +) +@pytest.mark.asyncio +async def test_post_login_form_credential_codes_raise_invalid_credentials(code): + """Auth0 codes containing credential keywords must surface InvalidCredentialsError.""" + session = MagicMock() + session.post = MagicMock(return_value=_acm(_ulp_error_resp(400, code))) + + with pytest.raises(InvalidCredentialsError): + await _post_login_form( + session, + "https://auth.ecobee.com/u/login/password", + "state-xyz", + {"password": "x"}, + ) + + +@pytest.mark.asyncio +async def test_post_login_form_unknown_code_raises_runtime_error(): + """An Auth0 error code without credential keywords becomes RuntimeError.""" + session = MagicMock() + session.post = MagicMock(return_value=_acm(_ulp_error_resp(400, "transient-503"))) + + with pytest.raises(RuntimeError) as exc_info: + await _post_login_form( + session, + "https://auth.ecobee.com/u/login/password", + "state-xyz", + {"password": "x"}, + ) + assert not isinstance(exc_info.value, InvalidCredentialsError) + + +@pytest.mark.asyncio +async def test_post_login_form_no_code_raises_runtime_error(): + """A non-redirect with no parseable error code becomes a plain RuntimeError.""" + session = MagicMock() + session.post = MagicMock(return_value=_acm(_ulp_error_resp(400, None))) + + with pytest.raises(RuntimeError) as exc_info: + await _post_login_form( + session, + "https://auth.ecobee.com/u/login/password", + "state-xyz", + {"password": "x"}, + ) + assert not isinstance(exc_info.value, InvalidCredentialsError) diff --git a/tests/test_config_flow.py b/tests/test_config_flow.py index 9ef3e13..10b8b03 100644 --- a/tests/test_config_flow.py +++ b/tests/test_config_flow.py @@ -247,3 +247,95 @@ async def test_reauth_flow(hass: HomeAssistant) -> None: assert result2["type"] == "abort" assert result2["reason"] == "reauth_successful" assert entry.data[CONF_REFRESH_TOKEN] == "fresh-rt" + + +async def test_reauth_flow_locks_email_to_entry(hass: HomeAssistant) -> None: + """Reauth re-uses the entry's stored email even if user-supplied data has none.""" + pem = DPoPKey.generate().to_pem_str() + entry = MockConfigEntry( + domain=DOMAIN, + unique_id="locked@example.com", + data={ + CONF_USERNAME: "locked@example.com", + CONF_REFRESH_TOKEN: "stale-rt", + CONF_DPOP_PEM: pem, + }, + ) + entry.add_to_hass(hass) + + result = await hass.config_entries.flow.async_init( + DOMAIN, + context={"source": "reauth", "entry_id": entry.entry_id}, + data=entry.data, + ) + assert result["step_id"] == "reauth_confirm" + + new_auth = _mock_auth(refresh_token="fresh-rt", email="locked@example.com") + login_mock = AsyncMock(return_value=new_auth) + with patch( + "custom_components.generac.config_flow.GeneracAuth.login", + login_mock, + ), patch("custom_components.generac.async_setup_entry", return_value=True): + await hass.config_entries.flow.async_configure( + result["flow_id"], + {CONF_PASSWORD: "new-pw"}, + ) + await hass.async_block_till_done() + + # login() must have been called with the entry's stored email, not anything user-supplied. + assert login_mock.await_count == 1 + call_args = login_mock.await_args + # email is the 2nd positional arg in GeneracAuth.login(session, email, password, ...) + args = call_args.args + kwargs = call_args.kwargs + if len(args) >= 2: + used_email = args[1] + else: + used_email = kwargs.get("email") + assert used_email == "locked@example.com" + + +async def test_reconfigure_flow_persists_scan_interval(hass: HomeAssistant) -> None: + """Scan interval supplied during reconfigure ends up in entry.options.""" + from custom_components.generac.const import CONF_SCAN_INTERVAL + + pem = DPoPKey.generate().to_pem_str() + entry = MockConfigEntry( + domain=DOMAIN, + data={ + CONF_USERNAME: "user@example.com", + CONF_REFRESH_TOKEN: "old-rt", + CONF_DPOP_PEM: pem, + }, + options={}, + ) + entry.add_to_hass(hass) + + with patch("custom_components.generac.async_setup_entry", return_value=True): + await hass.config_entries.async_setup(entry.entry_id) + await hass.async_block_till_done() + + result = await hass.config_entries.flow.async_init( + DOMAIN, + context={"source": "reconfigure", "entry_id": entry.entry_id}, + ) + assert result["step_id"] == "reconfigure" + + new_auth = _mock_auth(refresh_token="new-rt") + with patch( + "custom_components.generac.config_flow.GeneracAuth.login", + AsyncMock(return_value=new_auth), + ), patch("custom_components.generac.async_setup_entry", return_value=True), patch( + "custom_components.generac.async_unload_entry", return_value=True + ): + await hass.config_entries.flow.async_configure( + result["flow_id"], + { + CONF_USERNAME: "user@example.com", + CONF_PASSWORD: "new-pw", + CONF_SCAN_INTERVAL: 600, + }, + ) + await hass.async_block_till_done() + + assert entry.options.get(CONF_SCAN_INTERVAL) == 600 diff --git a/tests/test_coordinator.py b/tests/test_coordinator.py index 5ea2c59..1ae6571 100644 --- a/tests/test_coordinator.py +++ b/tests/test_coordinator.py @@ -1,9 +1,15 @@ """Test the Generac data update coordinator.""" +from datetime import timedelta from unittest.mock import AsyncMock from unittest.mock import MagicMock import pytest +from custom_components.generac.api import InvalidCredentialsException +from custom_components.generac.auth import InvalidGrantError +from custom_components.generac.const import CONF_SCAN_INTERVAL +from custom_components.generac.const import DEFAULT_SCAN_INTERVAL from custom_components.generac.coordinator import GeneracDataUpdateCoordinator +from homeassistant.exceptions import ConfigEntryAuthFailed from homeassistant.helpers.update_coordinator import UpdateFailed @@ -41,3 +47,43 @@ async def test_coordinator_update_data_fails(hass): with pytest.raises(UpdateFailed): await coordinator._async_update_data() assert not coordinator.is_online + + +async def test_coordinator_uses_default_scan_interval(hass): + """Without a scan_interval option, coordinator picks DEFAULT_SCAN_INTERVAL.""" + config_entry = MagicMock() + config_entry.options = {} + coordinator = GeneracDataUpdateCoordinator(hass, MagicMock(), config_entry) + assert coordinator.update_interval == timedelta(seconds=DEFAULT_SCAN_INTERVAL) + + +async def test_coordinator_honors_scan_interval_option(hass): + """A scan_interval option overrides the default.""" + config_entry = MagicMock() + config_entry.options = {CONF_SCAN_INTERVAL: 120} + coordinator = GeneracDataUpdateCoordinator(hass, MagicMock(), config_entry) + assert coordinator.update_interval == timedelta(seconds=120) + + +async def test_coordinator_invalid_credentials_raises_auth_failed(hass): + """InvalidCredentialsException → ConfigEntryAuthFailed (triggers reauth).""" + config_entry = MagicMock() + config_entry.options = {} + client = MagicMock() + client.async_get_data = AsyncMock(side_effect=InvalidCredentialsException("bad")) + coordinator = GeneracDataUpdateCoordinator(hass, client, config_entry) + with pytest.raises(ConfigEntryAuthFailed): + await coordinator._async_update_data() + assert not coordinator.is_online + + +async def test_coordinator_invalid_grant_raises_auth_failed(hass): + """InvalidGrantError → ConfigEntryAuthFailed (triggers reauth).""" + config_entry = MagicMock() + config_entry.options = {} + client = MagicMock() + client.async_get_data = AsyncMock(side_effect=InvalidGrantError("revoked")) + coordinator = GeneracDataUpdateCoordinator(hass, client, config_entry) + with pytest.raises(ConfigEntryAuthFailed): + await coordinator._async_update_data() + assert not coordinator.is_online diff --git a/tests/test_entity.py b/tests/test_entity.py index f8fc850..2ae05b1 100644 --- a/tests/test_entity.py +++ b/tests/test_entity.py @@ -54,3 +54,26 @@ async def test_entity(hass): entity._handle_coordinator_update() assert entity.item == new_item assert entity.async_write_ha_state.called + + +async def test_entity_missing_device_falls_back_to_empty_item(hass): + """If a device disappears from coordinator.data the entity stays alive but unavailable.""" + from custom_components.generac.entity import _EMPTY_ITEM + + coordinator = MagicMock() + coordinator.is_online = True + entry = MagicMock() + entry.entry_id = "test_entry_id" + + item = get_mock_item() + entity = GeneracEntity(coordinator, entry, "12345", item) + entity.hass = hass + entity.async_write_ha_state = MagicMock() + + # Simulate a coordinator refresh where this device is no longer reported. + coordinator.data = {} + entity._handle_coordinator_update() + + assert entity.item is _EMPTY_ITEM + assert entity.available is False + assert entity.async_write_ha_state.called diff --git a/tests/test_init.py b/tests/test_init.py index b24d6a8..5534477 100644 --- a/tests/test_init.py +++ b/tests/test_init.py @@ -118,3 +118,65 @@ async def test_unload_entry_failed(hass: HomeAssistant, bypass_get_data): ): assert not await async_unload_entry(hass, config_entry) assert config_entry.entry_id in hass.data[DOMAIN] + + +async def test_setup_entry_persist_callback_registered(hass, bypass_get_data): + """Setup wires the entry-update persist callback into GeneracAuth.""" + config_entry = MockConfigEntry( + domain=DOMAIN, data=_make_mock_config(), entry_id="test" + ) + config_entry.add_to_hass(hass) + + captured = {} + + def fake_from_storage(session, refresh_token, pem_str, email=None): + auth = type("FakeAuth", (), {})() + auth.set_refresh_token_persist_callback = lambda cb: captured.setdefault( + "cb", cb + ) + return auth + + with patch( + "custom_components.generac.GeneracAuth.from_storage", + side_effect=fake_from_storage, + ): + assert await async_setup_entry(hass, config_entry) + assert callable(captured.get("cb")) + + +async def test_setup_entry_invalid_credentials_raises_auth_failed(hass): + """First refresh raising InvalidCredentialsException → ConfigEntryAuthFailed.""" + from custom_components.generac.api import InvalidCredentialsException + + config_entry = MockConfigEntry( + domain=DOMAIN, data=_make_mock_config(), entry_id="test" + ) + config_entry.add_to_hass(hass) + + async def boom(self): + raise InvalidCredentialsException("nope") + + with patch( + "custom_components.generac.coordinator.GeneracDataUpdateCoordinator.async_config_entry_first_refresh", + boom, + ), pytest.raises(ConfigEntryAuthFailed): + await async_setup_entry(hass, config_entry) + + +async def test_setup_entry_invalid_grant_raises_auth_failed(hass): + """First refresh raising InvalidGrantError → ConfigEntryAuthFailed.""" + from custom_components.generac.auth import InvalidGrantError + + config_entry = MockConfigEntry( + domain=DOMAIN, data=_make_mock_config(), entry_id="test" + ) + config_entry.add_to_hass(hass) + + async def boom(self): + raise InvalidGrantError("revoked") + + with patch( + "custom_components.generac.coordinator.GeneracDataUpdateCoordinator.async_config_entry_first_refresh", + boom, + ), pytest.raises(ConfigEntryAuthFailed): + await async_setup_entry(hass, config_entry) diff --git a/tests/test_sensor.py b/tests/test_sensor.py index f5e5c76..344b974 100644 --- a/tests/test_sensor.py +++ b/tests/test_sensor.py @@ -270,3 +270,19 @@ async def test_propane_monitor_sensors(hass): sensor = DeviceTypeSensor(coordinator, entry, "12345", item) assert sensor.native_value == "lte-tankutility-v2" + + +async def test_safe_float_handles_none_and_invalid(hass): + """Bad sensor values (None, 'N/A', non-numeric) yield None instead of crashing.""" + coordinator = MagicMock() + entry = MagicMock() + item = get_mock_item(DEVICE_TYPE_GENERATOR, 1) + item.apparatusDetail.properties = [ + MagicMock(type=71, value=None), + MagicMock(type=32, value="N/A"), + MagicMock(type=70, value="not-a-number"), + ] + + assert RunTimeSensor(coordinator, entry, "12345", item).native_value is None + assert ProtectionTimeSensor(coordinator, entry, "12345", item).native_value is None + assert BatteryVoltageSensor(coordinator, entry, "12345", item).native_value is None From 8c550caf7320058dab5311d550fcb463602cb04a Mon Sep 17 00:00:00 2001 From: Stefan Slivinski Date: Wed, 29 Apr 2026 20:57:00 -0700 Subject: [PATCH 3/3] test: fix test_setup_entry_persist_callback_registered + reorder imports - Use hass.config_entries.async_setup() so entry transitions through SETUP_IN_PROGRESS, matching the pattern of existing tests in this file. Calling async_setup_entry directly when entry is NOT_LOADED triggers OperationNotAllowed in async_forward_entry_setups. - Remove blank line after module docstring per reorder-python-imports. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- tests/test_init.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_init.py b/tests/test_init.py index 5534477..816b2f4 100644 --- a/tests/test_init.py +++ b/tests/test_init.py @@ -1,5 +1,4 @@ """Test generac setup process.""" - from unittest.mock import patch import pytest @@ -140,7 +139,8 @@ def fake_from_storage(session, refresh_token, pem_str, email=None): "custom_components.generac.GeneracAuth.from_storage", side_effect=fake_from_storage, ): - assert await async_setup_entry(hass, config_entry) + await hass.config_entries.async_setup(config_entry.entry_id) + await hass.async_block_till_done() assert callable(captured.get("cb"))