diff --git a/.gitignore b/.gitignore index 7d8dee0..4629bfc 100644 --- a/.gitignore +++ b/.gitignore @@ -54,3 +54,5 @@ migration.py .github .pytest_cache .ruff_cache + +*.local.* diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 0000000..e5ac8af --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,105 @@ +# CLAUDE.md + +This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. + +## Project + +Watchly is a Stremio catalog addon that generates personalized movie/series recommendations from a user's watch history. It is a FastAPI service that speaks the Stremio addon protocol (manifest + catalog endpoints). Recommendations come from a taste profile built off the user's history, then candidates are pulled from TMDB / Simkl, scored, capped for diversity, enriched, and returned as a Stremio catalog. + +A user installs Watchly through its `/configure` web page: they paste a Stremio email/password (or auth_key), optionally connect Trakt and/or Simkl via OAuth, optionally provide their own TMDB / Gemini / Simkl / RPDB API keys, pick which catalogs they want, and get an addon manifest URL to paste into Stremio. From then on, every catalog row in their Stremio home — "Top Picks for You", "Because you loved …", "Genre & Keyword Catalogs", etc. — is served by this app. State per user is keyed on a short opaque token embedded in the manifest URL; credentials are encrypted at rest in Redis. The app must work for users who store their library in Stremio, in Trakt, or in Simkl, and for users with mixed signals (rated, watched, loved, rewatched). That source flexibility is the central architectural constraint. + +## Commands + +Dependencies are managed with [uv](https://github.com/astral-sh/uv); a `requirements.txt` is also kept in sync for non-uv environments. Python 3.12+. 
+ +```bash +# Install +uv sync + +# Run dev server (auto-reload when APP_ENV=development) +uv run main.py --dev +# or directly +uvicorn app.core.app:app --reload + +# Tests (pytest is not in requirements-dev.txt — install once into the venv) +pip install pytest pytest-asyncio +pytest tests/ # all tests +pytest tests/test_catalog_endpoint.py -v # single file +pytest tests/test_catalog_endpoint.py::test_name # single test + +# Lint / format (also runs on commit via pre-commit) +pre-commit run --all-files +black . # line length 120, py312 +isort . # black profile +flake8 . # max-line-length 120, config in setup.cfg + +# Docker +docker-compose up -d # uses env_file .env +``` + +The configure UI is served at `/configure`. Required env vars: `TMDB_API_KEY`, `TOKEN_SALT`, `HOST_NAME`. Redis is required (`REDIS_URL`). + +## Architecture + +### Request flow + +Every catalog request resolves through one path: + +1. **`app/services/context.py:load_user_context`** is the entry point for every authenticated endpoint. It reads the encrypted token from Redis, decrypts credentials, parses `UserSettings`, resolves a Stremio `auth_key`, and builds the `LibraryCollection`. The library is sourced from `user_settings.watch_history_source` — `"stremio"`, `"trakt"`, or `"simkl"`. For external sources the WatchHistory is converted to a `LibraryCollection` (rating ≥ 9 → loved, 7–8.9 → liked, no-rating + rewatch → loved fallback, else watched) so downstream catalog code is source-agnostic. The `LibraryCollection.source` field drives cache invalidation when a user switches sources. +2. 
**`app/services/recommendation/catalog_service.py`** routes the catalog ID to one of the recommendation engines: + - `watchly.rec` → `TopPicksService` (combines profile-driven Discover + library-seeded TMDB/Simkl recs) + - `watchly.theme.*` → `ThemeBasedService` (genre/keyword/era driven) + - `watchly.item.*` → `ItemBasedService` (seeded by a single library item — see "watchly.item" below) + - `watchly.creators` → `CreatorsService` (directors/cast) + - `watchly.all.loved`, `watchly.liked.all` → `AllBasedService` +3. The engine returns a list of items that are passed through metadata enrichment (`app/services/recommendation/metadata.py`), poster ratings overlay (`app/services/poster_ratings/`), translation, and serialization. + +### Taste profile pipeline (`app/services/profile/`) + +The `TasteProfile` is a numerical fingerprint of the user — top genres, keywords, directors, cast, eras, countries, runtime preference. It is built from the same source as the library: `ProfileService.build_and_cache_profile` checks the configured `watch_history_source` and feeds `WatchHistoryItem`s through the same vectorizer pipeline regardless of origin. Profiles are cached in Redis per-token-per-content-type and invalidated when the source field doesn't match. `_build_from_external_source` reuses the already-built `LibraryCollection` when its `source` matches the configured source, avoiding a duplicate Trakt/Simkl fetch. + +### External API clients + +All HTTP calls go through **`app/core/base_client.py:BaseClient`**, which provides retries (with jitter on 429/5xx), timeouts, structured error logging, and safe JSON parsing. `TraktService`, `SimklService`, and `TMDBService` are singletons that wrap `BaseClient`. The token-refresh + 401-revoke flow for Trakt/Simkl lives in `ProfileService.fetch_external_watch_history` and is shared between context loading and profile building. 
+ +### Caching (`app/services/user_cache.py`, `app/services/redis_service.py`) + +Redis is the source of truth for user state. Per-token cached: encrypted credentials (`token_store`), library collection, taste profile (per content type), watched-id sets, library hash for incremental rebuilds. Many caches are TTL-bound (90d default for user data) and refresh on read so active users stay warm. **Invalidate library + profile on source switch**, not just on settings change — `load_user_context` and `build_and_cache_profile` both check the cached `source` field. + +### Catalog config IDs + +User catalog config IDs (in `UserSettings.catalogs`) and the IDs Stremio actually requests are different. Configs use the bare ID (`watchly.theme`, `watchly.item`); served catalogs append the seed (`watchly.theme.action`, `watchly.item.tt0468569`). `get_config_id` in `app/services/catalog_definitions.py` strips the suffix to look up settings. + +**Legacy IDs**: the previously separate `watchly.loved` and `watchly.watched` were merged into a single `watchly.item` catalog. Routing in `catalog_service.py` and `get_config_id` still accept `watchly.loved.*` / `watchly.watched.*` prefixes because installed Stremio clients keep requesting them until the manifest refreshes; `_resolve_catalog_configs` synthesizes a `watchly.item` config from any legacy entries left in saved settings. + +### Settings + catalog defaults + +`app/core/settings.py:get_default_settings()` is the single source of truth for the default catalog list and shape. Frontend pulls these via `get_default_catalogs_for_frontend()` so the configure page and backend can't drift. When adding a new catalog: add the `CatalogConfig` to defaults, add a description to `CATALOG_DESCRIPTIONS`, register routing in `app/services/recommendation/catalog_service.py`, and emit it from `DynamicCatalogService.get_dynamic_catalogs` in `app/services/catalog_definitions.py`. 
+ +### Background work + +`app/services/catalog_updater.py` runs on a schedule (`AUTO_UPDATE_CATALOGS=true` + `CATALOG_REFRESH_INTERVAL`) to refresh dynamic catalogs ahead of user requests. Background tasks created via `asyncio.create_task` must be retained (see `app/services/catalog_updater.py:125`) — bare creates are GC-eligible and silently swallow errors. + +## Coding standards + +The codebase aims for code that reads like prose: small functions, intention-revealing names, and as little ceremony as possible. Match that. New code that is denser, more abstract, or more defensive than the surrounding files is a regression. + +- **Follow standard Python idioms.** PEP 8 spacing/naming, type hints on every public function and dataclass field, `pydantic` models for anything that crosses an API boundary, `loguru` for logging (don't import `logging`), `httpx` for HTTP (always through `BaseClient`), `async` end-to-end for I/O. No threads, no synchronous blocking calls inside async handlers. +- **Comments and docstrings: write them only when the WHY is non-obvious.** A function name and its signature should explain WHAT it does. Add a comment or docstring when there's a hidden constraint, a workaround, a subtle invariant, or behavior that would surprise the next reader (e.g. "Trakt list endpoints decode to a `list` despite the dict type hint" or "we drop the cached library on source switch because otherwise stale results are served"). Do not narrate happy-path code, do not write what-it-does docstrings, do not add `# added for X` rot. +- **Refactor when a function grows past ~40 lines or two responsibilities.** Examples already in the repo: `_build_from_external_source` was split out of `build_and_cache_profile` once dispatch logic appeared; `fetch_external_watch_history` was extracted once two call sites needed the same Trakt/Simkl flow. Don't pre-extract a helper that only has one caller. 
+- **No bloat.** Don't add error handling for cases that can't happen, don't validate input that's already typed, don't add backwards-compat shims unless an actual installed client depends on the old shape (Stremio manifest IDs are the main case — see legacy catalog ID handling). Three similar lines beat a premature abstraction. Delete dead code rather than leaving it with `# unused`. +- **Centralize, don't repeat.** TMDB / Trakt / Simkl calls go through their service classes, never raw `httpx`. Catalog defaults live in `get_default_settings`, not duplicated in templates. ID-prefix knowledge belongs in `get_config_id` and `_get_recommendations` routing, not scattered across modules. +- **Caches are part of the contract.** When you change the shape of something cached (LibraryCollection, TasteProfile, watched sets), think about cache invalidation. Adding a field is safe (Pydantic ignores unknowns or defaults them); changing semantics needs a versioned key or an explicit invalidate. +- **Line length 120** everywhere (black, isort, flake8 all aligned in `setup.cfg` and `pyproject.toml`). Pre-commit hooks enforce on every commit; black will reformat your file and the commit will need to be retried. + +## Commit conventions + +- **Never add a `Co-Authored-By` trailer.** Commits are authored by the human, not by the assistant. No `🤖 Generated with` lines either. +- **Stage only the files relevant to the commit** — `git add `, never `git add -A`/`git add .`. Unrelated working-tree changes (e.g. local `.gitignore` tweaks, scratch files) stay unstaged. +- **One fix per commit.** If a session produces two logically separate fixes, ship two commits so either can be reverted independently. Prefix with the area in the existing repo style: `fix(library): …`, `refactor(catalogs): …`, `feat(trakt): …`, `chore(profile): …`. + +## Domain conventions + +- **One source, one library**: never mix Stremio library items with Trakt/Simkl items in the same `LibraryCollection`. 
The whole collection is tagged with a single `source`. +- **Item exclusion uses both ID kinds**: `watched_imdb` (set of `tt…`) and `watched_tmdb` (set of TMDB ints). External sources only populate `watched_imdb` reliably; don't assume `watched_tmdb` is populated for Trakt/Simkl users. +- **`BaseClient.get/post` returns `dict` typed**, but JSON list responses (Trakt) decode to `list`. Defensive `_safe_list` guards in service layers handle this — preserve the pattern rather than tightening the type. diff --git a/app/api/endpoints/catalogs.py b/app/api/endpoints/catalogs.py index f1f0763..e2abd89 100644 --- a/app/api/endpoints/catalogs.py +++ b/app/api/endpoints/catalogs.py @@ -1,3 +1,5 @@ +import re + from fastapi import APIRouter, HTTPException, Response from loguru import logger @@ -6,6 +8,10 @@ router = APIRouter() +# Stremio auth tokens are short (~24 char) hex/alphanumeric strings. Accept up +# to 32 chars of [A-Za-z0-9] as a sanity check; anything else is malformed. +_TOKEN_PATTERN = re.compile(r"^[A-Za-z0-9]{1,32}$") + @router.get("/{token}/catalog/{type}/{id}.json") @router.get("/{token}/catalog/{type}/{id}/{extra}.json") @@ -13,7 +19,7 @@ async def get_catalog(response: Response, type: str, id: str, token: str, extra: if type not in ("movie", "series"): raise HTTPException(status_code=400, detail="Invalid content type. Must be 'movie' or 'series'.") - if len(token) > 30: # normal stremio tokens are 24 length. But we are using this just to be safe. + if not _TOKEN_PATTERN.match(token): raise HTTPException(status_code=400, detail="Invalid token.") try: @@ -24,8 +30,8 @@ async def get_catalog(response: Response, type: str, id: str, token: str, extra: for key, value in headers.items(): response.headers[key] = value - # if recommendations are none or empty, then set cache header to no-cache - if recommendations and not recommendations.get("meta"): + # If recommendations are empty, avoid caching the empty payload aggressively. 
+ if recommendations is not None and not recommendations.get("metas"): response.headers["Cache-Control"] = "no-cache" return recommendations @@ -34,4 +40,4 @@ async def get_catalog(response: Response, type: str, id: str, token: str, extra: raise except Exception as e: logger.exception(f"[{redact_token(token)}] Error fetching catalog for {type}/{id}: {e}") - raise HTTPException(status_code=500, detail=f"Something went wrong. Please try again. Error: {e}") + raise HTTPException(status_code=500, detail="Something went wrong. Please try again.") diff --git a/app/api/endpoints/health.py b/app/api/endpoints/health.py index 0e339e9..0b03029 100644 --- a/app/api/endpoints/health.py +++ b/app/api/endpoints/health.py @@ -1,8 +1,9 @@ from fastapi import APIRouter +from fastapi.responses import JSONResponse router = APIRouter(tags=["health"]) @router.get("/health", summary="Simple readiness probe") -async def health_check() -> dict[str, str]: - return {"status": "ok"} +async def health_check() -> JSONResponse: + return JSONResponse(status_code=200, content={"status": "healthy"}) diff --git a/app/api/endpoints/languages.py b/app/api/endpoints/languages.py new file mode 100644 index 0000000..ea96575 --- /dev/null +++ b/app/api/endpoints/languages.py @@ -0,0 +1,33 @@ +from fastapi import APIRouter, HTTPException, Query +from loguru import logger + +from app.services.language_service import fetch_languages_list +from app.services.tmdb.service import get_tmdb_service + +router = APIRouter() + + +@router.get("/api/languages") +async def get_languages(): + try: + languages = await fetch_languages_list() + return languages + except Exception as e: + logger.error(f"Failed to fetch languages: {e}") + raise HTTPException(status_code=502, detail=f"Failed to fetch languages from TMDB: {e}") + + +@router.get("/api/meta/images") +async def get_meta_images( + media_type: str = Query(..., description="movie or tv"), + tmdb_id: int = Query(..., description="TMDB ID"), + language: str = 
Query("en-US", description="Language preference (e.g. en-US, fr-FR)"), +): + """Fetch language-aware poster, logo, and background images for a title.""" + try: + tmdb_service = get_tmdb_service(language=language) + images = await tmdb_service.get_images_for_title(media_type, tmdb_id, language=language) + return images + except Exception as e: + logger.error(f"Failed to fetch images for {media_type}/{tmdb_id}: {e}") + raise HTTPException(status_code=502, detail=f"Failed to fetch images from TMDB: {e}") diff --git a/app/api/endpoints/manifest.py b/app/api/endpoints/manifest.py index 30261f5..f533010 100644 --- a/app/api/endpoints/manifest.py +++ b/app/api/endpoints/manifest.py @@ -7,7 +7,6 @@ @router.get("/manifest.json") async def manifest(): - """Get base manifest for unauthenticated users.""" manifest = manifest_service.get_base_manifest() # since user is not logged in, return empty catalogs manifest["catalogs"] = [] @@ -16,5 +15,4 @@ async def manifest(): @router.get("/{token}/manifest.json") async def manifest_token(token: str): - """Get manifest for authenticated user.""" return await manifest_service.get_manifest_for_token(token) diff --git a/app/api/endpoints/meta.py b/app/api/endpoints/meta.py deleted file mode 100644 index 8e35222..0000000 --- a/app/api/endpoints/meta.py +++ /dev/null @@ -1,93 +0,0 @@ -import asyncio - -from fastapi import APIRouter, HTTPException, Query -from loguru import logger - -from app.services.tmdb.service import get_tmdb_service - -router = APIRouter() - - -async def fetch_languages_list(): - """ - Fetch and format languages list from TMDB. - Returns a list of language dictionaries with iso_639_1, language, and country. 
- """ - tmdb = get_tmdb_service() - tasks = [ - tmdb.get_primary_translations(), - tmdb.get_languages(), - tmdb.get_countries(), - ] - primary_translations, languages, countries = await asyncio.gather(*tasks) - - language_map = {lang["iso_639_1"]: lang["english_name"] for lang in languages} - country_map = {country["iso_3166_1"]: country["english_name"] for country in countries} - - result = [] - for element in primary_translations: - # element looks like "en-US" - parts = element.split("-") - if len(parts) != 2: - continue - - lang_code, country_code = parts - language_name = language_map.get(lang_code) - country_name = country_map.get(country_code) - - if language_name and country_name: - result.append( - { - "iso_639_1": element, - "language": language_name, - "country": country_name, - } - ) - result.sort(key=lambda x: (x["iso_639_1"] != "en-US", x["language"])) - return result - - -@router.get("/api/languages") -async def get_languages(): - try: - languages = await fetch_languages_list() - return languages - except Exception as e: - logger.error(f"Failed to fetch languages: {e}") - raise HTTPException(status_code=502, detail="Failed to fetch languages from TMDB") - - -@router.get("/api/meta/images") -async def get_meta_images( - imdb_id: str | None = Query(None, description="IMDb ID (e.g. tt1234567)"), - tmdb_id: int | None = Query(None, description="TMDB ID (use with kind)"), - kind: str = Query("movie", description="Type: movie or series"), - language: str = Query("en-US", description="Language for image preference (e.g. en-US, fr-FR)"), -): - """ - Return logo, poster and background in the requested language. - Provide either imdb_id (and optionally kind) or tmdb_id + kind. 
- """ - try: - tmdb = get_tmdb_service(language=language) - media_type = "tv" if kind == "series" else "movie" - - if imdb_id: - clean_imdb = imdb_id.strip().lower() - if not clean_imdb.startswith("tt"): - clean_imdb = "tt" + clean_imdb - tid, found_type = await tmdb.find_by_imdb_id(clean_imdb) - if tid is None: - raise HTTPException(status_code=404, detail="Title not found on TMDB") - media_type = found_type - tmdb_id = tid - elif tmdb_id is None: - raise HTTPException(status_code=400, detail="Provide imdb_id or tmdb_id") - - images = await tmdb.get_images_for_title(media_type, tmdb_id, language=language) - return images - except HTTPException: - raise - except Exception as e: - logger.error(f"Failed to fetch meta images: {e}") - raise HTTPException(status_code=502, detail="Failed to fetch images from TMDB") diff --git a/app/api/endpoints/oauth.py b/app/api/endpoints/oauth.py new file mode 100644 index 0000000..1b5f745 --- /dev/null +++ b/app/api/endpoints/oauth.py @@ -0,0 +1,206 @@ +import secrets +import time +from urllib.parse import urlencode + +from fastapi import APIRouter, HTTPException, Request +from fastapi.responses import HTMLResponse, RedirectResponse +from loguru import logger + +from app.core.config import settings +from app.services.simkl import simkl_service +from app.services.trakt import trakt_service + +router = APIRouter(tags=["OAuth"]) + +# Short-lived cookie name + lifetime for OAuth CSRF state. 
+_OAUTH_STATE_COOKIE_PREFIX = "watchly_oauth_state_" +_OAUTH_STATE_TTL_SECONDS = 600 # 10 minutes + + +def _set_state_cookie(response, provider: str, state: str) -> None: + response.set_cookie( + key=f"{_OAUTH_STATE_COOKIE_PREFIX}{provider}", + value=state, + max_age=_OAUTH_STATE_TTL_SECONDS, + httponly=True, + secure=settings.APP_ENV == "production", + samesite="lax", + path="/", + ) + + +def _verify_state(request: Request, provider: str, state: str | None) -> None: + expected = request.cookies.get(f"{_OAUTH_STATE_COOKIE_PREFIX}{provider}") + if not state or not expected or not secrets.compare_digest(state, expected): + raise HTTPException(status_code=400, detail="Invalid or missing OAuth state. Please try connecting again.") + + +# ── Trakt OAuth ────────────────────────────────────────────────────────────── + +TRAKT_AUTH_URL = "https://trakt.tv/oauth/authorize" + + +@router.get("/auth/trakt") +async def trakt_auth_redirect(request: Request): + """Redirect user to Trakt authorization page.""" + if not settings.TRAKT_CLIENT_ID: + raise HTTPException(status_code=501, detail="Trakt integration is not configured on this server.") + + redirect_uri = f"{settings.HOST_NAME}/auth/trakt/callback" + state = secrets.token_urlsafe(32) + params = urlencode( + { + "response_type": "code", + "client_id": settings.TRAKT_CLIENT_ID, + "redirect_uri": redirect_uri, + "state": state, + } + ) + response = RedirectResponse(f"{TRAKT_AUTH_URL}?{params}") + _set_state_cookie(response, "trakt", state) + return response + + +@router.get("/auth/trakt/callback", response_class=HTMLResponse) +async def trakt_callback(request: Request, code: str, state: str | None = None): + """Handle Trakt OAuth callback, exchange code for tokens.""" + if not settings.TRAKT_CLIENT_ID or not settings.TRAKT_CLIENT_SECRET: + raise HTTPException(status_code=501, detail="Trakt integration is not configured on this server.") + + _verify_state(request, "trakt", state) + + redirect_uri = 
f"{settings.HOST_NAME}/auth/trakt/callback" + + try: + token_data = await trakt_service.exchange_code(code, redirect_uri) + access_token = token_data.get("access_token", "") + refresh_token = token_data.get("refresh_token", "") + # Trakt returns expires_in (seconds, ~3 months) and created_at (epoch). + # Compute the absolute expiry up front so we can refresh proactively + # without re-deriving it later. + expires_in = int(token_data.get("expires_in") or 0) + created_at = int(token_data.get("created_at") or time.time()) + expires_at = created_at + expires_in if expires_in else 0 + + # Fetch username for display + user_info = await trakt_service.get_user_info(access_token) + username = user_info.get("user", {}).get("username") or user_info.get("username", "Unknown") + except Exception as e: + logger.error(f"Trakt OAuth callback failed: {e}") + return HTMLResponse(_oauth_error_page("Trakt", str(e))) + + return HTMLResponse( + _oauth_success_page( + provider="trakt", + username=username, + tokens={ + "access_token": access_token, + "refresh_token": refresh_token, + "expires_at": expires_at, + }, + ) + ) + + +# ── Simkl OAuth ────────────────────────────────────────────────────────────── + +SIMKL_AUTH_URL = "https://simkl.com/oauth/authorize" + + +@router.get("/auth/simkl") +async def simkl_auth_redirect(request: Request): + """Redirect user to Simkl authorization page.""" + if not settings.SIMKL_CLIENT_ID or not settings.SIMKL_CLIENT_SECRET: + raise HTTPException(status_code=501, detail="Simkl integration is not configured on this server.") + + redirect_uri = f"{settings.HOST_NAME}/auth/simkl/callback" + state = secrets.token_urlsafe(32) + params = urlencode( + { + "response_type": "code", + "client_id": settings.SIMKL_CLIENT_ID, + "redirect_uri": redirect_uri, + "state": state, + } + ) + response = RedirectResponse(f"{SIMKL_AUTH_URL}?{params}") + _set_state_cookie(response, "simkl", state) + return response + + +@router.get("/auth/simkl/callback", 
response_class=HTMLResponse) +async def simkl_callback(request: Request, code: str, state: str | None = None): + """Handle Simkl OAuth callback, exchange code for tokens.""" + if not settings.SIMKL_CLIENT_ID or not settings.SIMKL_CLIENT_SECRET: + raise HTTPException(status_code=501, detail="Simkl integration is not configured on this server.") + + _verify_state(request, "simkl", state) + + redirect_uri = f"{settings.HOST_NAME}/auth/simkl/callback" + + try: + token_data = await simkl_service.exchange_code( + code, + redirect_uri, + settings.SIMKL_CLIENT_ID, + settings.SIMKL_CLIENT_SECRET, + ) + access_token = token_data.get("access_token", "") + + user_info = await simkl_service.get_user_settings(access_token, settings.SIMKL_CLIENT_ID) + username = user_info.get("user", {}).get("name") or user_info.get("account", {}).get("id", "Unknown") + except Exception as e: + logger.error(f"Simkl OAuth callback failed: {e}") + return HTMLResponse(_oauth_error_page("Simkl", str(e))) + + return HTMLResponse( + _oauth_success_page( + provider="simkl", + username=str(username), + tokens={"access_token": access_token}, + ) + ) + + +# ── HTML helpers ───────────────────────────────────────────────────────────── + + +def _oauth_success_page(provider: str, username: str, tokens: dict[str, str]) -> str: + """Generate a callback page that sends tokens back to the opener window.""" + import html + import json + from urllib.parse import urlparse + + safe_username = html.escape(username or "") + safe_provider = html.escape(provider.title()) + payload = json.dumps({"provider": provider, "username": username, "tokens": tokens}) + # Pin postMessage target to the configured app origin so we never broadcast + # tokens to whatever origin the popup happens to be on. + parsed = urlparse(settings.HOST_NAME or "") + target_origin = json.dumps(f"{parsed.scheme}://{parsed.netloc}" if parsed.scheme and parsed.netloc else "/") + return f""" + {safe_provider} Connected + +

Connected as {safe_username}

+

You can close this window.

+ + """ + + +def _oauth_error_page(provider: str, error: str) -> str: + import html + + safe_provider = html.escape(provider.title()) + safe_error = html.escape(error or "") + return f""" +{safe_provider} Error + +

{safe_provider} login failed

+

{safe_error}

+

Please close this window and try again.

+""" diff --git a/app/api/endpoints/tokens.py b/app/api/endpoints/tokens.py index bea80c0..958bf7b 100644 --- a/app/api/endpoints/tokens.py +++ b/app/api/endpoints/tokens.py @@ -1,229 +1,67 @@ -from datetime import datetime, timezone -from typing import Literal - -from fastapi import APIRouter, HTTPException, Request +from fastapi import APIRouter, HTTPException +from fastapi.responses import JSONResponse from loguru import logger -from pydantic import BaseModel, Field -from app.core.config import settings +from app.api.models.tokens import TokenRequest, TokenResponse from app.core.security import redact_token -from app.core.settings import CatalogConfig, PosterRatingConfig, UserSettings, get_default_settings -from app.services.manifest import manifest_service +from app.services.auth import auth_service from app.services.stremio.service import StremioBundle -from app.services.token_store import token_store - -router = APIRouter(prefix="/tokens", tags=["tokens"]) - - -class TokenRequest(BaseModel): - authKey: str | None = Field(default=None, description="Stremio auth key") - email: str | None = Field(default=None, description="Stremio account email") - password: str | None = Field(default=None, description="Stremio account password (stored securely)") - catalogs: list[CatalogConfig] | None = Field(default=None, description="Optional catalog configuration") - language: str = Field(default="en-US", description="Language for TMDB API") - poster_rating: PosterRatingConfig | None = Field(default=None, description="Poster rating provider configuration") - excluded_movie_genres: list[str] = Field(default_factory=list, description="List of movie genre IDs to exclude") - excluded_series_genres: list[str] = Field(default_factory=list, description="List of series genre IDs to exclude") - popularity: Literal["mainstream", "balanced", "gems", "all"] = Field( - default="balanced", description="Popularity for TMDB API" - ) - year_min: int = Field(default=2010, description="Minimum 
release year for TMDB API") - year_max: int = Field(default=2025, description="Maximum release year for TMDB API") - sorting_order: Literal["default", "movies_first", "series_first"] = Field( - default="default", description="Order of movies and series catalogs" - ) - simkl_api_key: str | None = Field(default=None, description="Simkl API Key for the user") - gemini_api_key: str | None = Field(default=None, description="Gemini API Key for AI features") - tmdb_api_key: str | None = Field( - default=None, description="TMDB API Key (required for new clients if server has none)" - ) - - -class TokenResponse(BaseModel): - token: str - manifestUrl: str - expiresInSeconds: int | None = Field( - default=None, - description="Number of seconds before the token expires (None means it does not expire)", - ) +router = APIRouter(prefix="/tokens", tags=["Tokens"]) -async def _verify_credentials_or_raise(bundle: StremioBundle, auth_key: str) -> str: - """Ensure the supplied auth key is valid.""" - try: - await bundle.auth.get_user_info(auth_key) - return auth_key - except Exception as exc: - raise HTTPException( - status_code=400, - detail="Invalid Stremio auth key.", - ) from exc - - -@router.post("/", response_model=TokenResponse) -async def create_token(payload: TokenRequest, request: Request) -> TokenResponse: - # Prefer email+password if provided; else require authKey - email = (payload.email or "").strip() or None - password = (payload.password or "").strip() or None - stremio_auth_key = (payload.authKey or "").strip() or None - if not (email and password) and not stremio_auth_key: - raise HTTPException(status_code=400, detail="Provide email+password or a valid Stremio auth key.") - - # Remove quotes if present for authKey - if stremio_auth_key and stremio_auth_key.startswith('"') and stremio_auth_key.endswith('"'): - stremio_auth_key = stremio_auth_key[1:-1].strip() +async def _trigger_initial_caching(auth_key: str, user_settings, token: str) -> None: + """Cache library and 
profiles after token creation. Failures are non-blocking.""" + from app.services.manifest import manifest_service bundle = StremioBundle() - # 1. Establish a valid auth key and fetch user info - if email and password: - stremio_auth_key = await bundle.auth.login(email, password) - - try: - user_info = await bundle.auth.get_user_info(stremio_auth_key) - user_id = user_info["user_id"] - resolved_email = user_info.get("email", "") - except Exception as e: - raise HTTPException(status_code=400, detail=f"Failed to verify Stremio identity: {e}") - - # 2. Check if user already exists - token = token_store.get_token_from_user_id(user_id) - existing_data = await token_store.get_user_data(token) - - # 3. Construct Settings - default_settings = get_default_settings() - poster_rating = payload.poster_rating - user_settings = UserSettings( - language=payload.language or default_settings.language, - catalogs=payload.catalogs if payload.catalogs else default_settings.catalogs, - poster_rating=poster_rating, - excluded_movie_genres=payload.excluded_movie_genres, - excluded_series_genres=payload.excluded_series_genres, - year_min=payload.year_min, - year_max=payload.year_max, - popularity=payload.popularity, - sorting_order=payload.sorting_order, - simkl_api_key=payload.simkl_api_key, - gemini_api_key=payload.gemini_api_key, - tmdb_api_key=payload.tmdb_api_key, - ) - - # 4. Prepare payload to store - payload_to_store = { - "authKey": stremio_auth_key, - "email": resolved_email or email or "", - "settings": user_settings.model_dump(), - } - if existing_data: - payload_to_store["last_updated"] = existing_data.get("last_updated") - else: - payload_to_store["last_updated"] = datetime.now(timezone.utc).isoformat() - - if email and password: - payload_to_store["password"] = password - - # 5. 
Store user data - token = await token_store.store_user_data(user_id, payload_to_store) - account_status = "updated" if existing_data else "created" - logger.info(f"[{redact_token(token)}] Account {account_status} for user {user_id}") - - # 6. Cache library items and profiles before returning - # This ensures manifest generation is fast when user installs the addon - # We wait for caching to complete so everything is ready immediately try: logger.info(f"[{redact_token(token)}] Caching library and profiles before returning token") - await manifest_service.cache_library_and_profiles(bundle, stremio_auth_key, user_settings, token) + await manifest_service.cache_library_and_profiles(bundle, auth_key, user_settings, token) logger.info(f"[{redact_token(token)}] Successfully cached library and profiles") except Exception as e: logger.warning( f"[{redact_token(token)}] Failed to cache library and profiles: {e}. " "Continuing anyway - will cache on manifest request." ) - # Continue even if caching fails - manifest service will handle it - - base_url = settings.HOST_NAME - manifest_url = f"{base_url}/{token}/manifest.json" - expires_in = settings.TOKEN_TTL_SECONDS if settings.TOKEN_TTL_SECONDS > 0 else None - - await bundle.close() - - return TokenResponse( - token=token, - manifestUrl=manifest_url, - expiresInSeconds=expires_in, - ) + finally: + await bundle.close() -async def get_stremio_user_data(payload: TokenRequest) -> tuple[str, str]: - bundle = StremioBundle() +@router.post("/", response_model=TokenResponse) +async def create_token(payload: TokenRequest) -> TokenResponse: try: - email = (payload.email or "").strip() or None - password = (payload.password or "").strip() or None - auth_key = (payload.authKey or "").strip() or None - - if email and password: - try: - auth_key = await bundle.auth.login(email, password) - user_info = await bundle.auth.get_user_info(auth_key) - return user_info["user_id"], user_info.get("email", email) - except Exception as e: - 
logger.error(f"Stremio identity check failed: {e}") - raise HTTPException(status_code=400, detail="Failed to verify Stremio identity.") - elif auth_key: - if auth_key.startswith('"') and auth_key.endswith('"'): - auth_key = auth_key[1:-1].strip() - try: - user_info = await bundle.auth.get_user_info(auth_key) - return user_info["user_id"], user_info.get("email", "") - except Exception as e: - logger.error(f"Stremio identity check failed: {e}") - raise HTTPException(status_code=400, detail="Invalid Stremio auth key.") - else: - raise HTTPException(status_code=400, detail="Credentials required.") - finally: - await bundle.close() + response, auth_key, user_settings = await auth_service.create_user_token(payload) + await _trigger_initial_caching(auth_key, user_settings, response.token) + return response + except HTTPException: + raise + except Exception as exc: + logger.exception(f"Token creation failed: {exc}") + raise HTTPException(status_code=503, detail="Storage temporarily unavailable.") @router.post("/stremio-identity", status_code=200) async def check_stremio_identity(payload: TokenRequest): - """Fetch user info from Stremio and check if account exists.""" - user_id, email = await get_stremio_user_data(payload) try: - token = token_store.get_token_from_user_id(user_id) - user_data = await token_store.get_user_data(token) - exists = bool(user_data) - except Exception: - exists = False - user_data = None - - response = {"user_id": user_id, "email": email, "exists": exists} - if exists and user_data: - # Reconstruct UserSettings to ensure defaults (like sorting_order) are included for old accounts - raw_settings = user_data.get("settings", {}) - try: - user_settings = UserSettings(**raw_settings) - response["settings"] = user_settings.model_dump() - except Exception as e: - logger.warning(f"Failed to normalize settings for user {user_id}: {e}") - response["settings"] = raw_settings - return response + return await auth_service.get_identity_with_settings(payload) + 
except HTTPException: + raise + except Exception as exc: + logger.exception(f"Identity check failed: {exc}") + raise HTTPException(status_code=503, detail="Service temporarily unavailable.") @router.delete("/", status_code=200) async def delete_redis_token(payload: TokenRequest): - """Delete a token based on Stremio credentials.""" try: - user_id, _ = await get_stremio_user_data(payload) - token = token_store.get_token_from_user_id(user_id) - existing_data = await token_store.get_user_data(token) - if not existing_data: - raise HTTPException(status_code=404, detail="Account not found.") - - await token_store.delete_token(token) - logger.info(f"[{redact_token(token)}] Token deleted for user {user_id}") - return {"detail": "Settings deleted successfully"} + await auth_service.delete_user_account(payload) + return JSONResponse( + status_code=200, + content={"status": "ok", "message": "Settings deleted successfully"}, + ) except HTTPException: raise except Exception as exc: - logger.error(f"Token deletion failed: {exc}") - raise HTTPException(status_code=503, detail="Storage temporarily unavailable.") + logger.exception(f"Account deletion failed: {exc}") + raise HTTPException(status_code=503, detail="Service temporarily unavailable.") diff --git a/app/api/endpoints/validation.py b/app/api/endpoints/validation.py index e324116..e3cf797 100644 --- a/app/api/endpoints/validation.py +++ b/app/api/endpoints/validation.py @@ -1,11 +1,13 @@ from fastapi import APIRouter, HTTPException from google import genai from loguru import logger +from pydantic import BaseModel, Field from app.api.models.validation import BaseValidationInput, BaseValidationResponse, PosterRatingValidationInput from app.services.poster_ratings.factory import PosterProvider, poster_ratings_factory from app.services.simkl import simkl_service from app.services.tmdb.client import TMDBClient +from app.services.trakt import trakt_service router = APIRouter(tags=["Validation"]) @@ -54,9 +56,11 @@ async def 
validate_poster_rating_api_key(payload: PosterRatingValidationInput) - if is_valid: return BaseValidationResponse(valid=True, message="API key is valid") return BaseValidationResponse(valid=False, message="Invalid API key") + except HTTPException: + raise except Exception as e: - logger.error(f"Validation failed: {str(e)}") - raise HTTPException(status_code=500, detail="Validation failed due to an internal error.") + logger.error(f"Poster rating validation failed: {str(e)}") + return BaseValidationResponse(valid=False, message="Could not validate API key. Please try again.") @router.post("/simkl/validation") @@ -67,5 +71,49 @@ async def validate_simkl_api_key(data: BaseValidationInput) -> BaseValidationRes return BaseValidationResponse(valid=True, message="Valid API Key") return BaseValidationResponse(valid=False, message="Invalid API Key") except Exception as e: - logger.error(f"Validation failed: {str(e)}") - raise HTTPException(status_code=500, detail="Validation failed due to an internal error.") + logger.error(f"Simkl validation failed: {str(e)}") + return BaseValidationResponse(valid=False, message="Could not validate API key. 
Please try again.") + + +class OAuthTokenValidationInput(BaseModel): + access_token: str = Field(description="OAuth access token to validate") + + +@router.post("/trakt/validation") +async def validate_trakt_token(data: OAuthTokenValidationInput) -> BaseValidationResponse: + """Validate a Trakt OAuth access token by calling /users/me.""" + try: + user_info = await trakt_service.get_user_info(data.access_token) + username = user_info.get("user", {}).get("username") or user_info.get("username", "") + return BaseValidationResponse(valid=True, message=f"Connected as {username}") + except Exception as e: + logger.debug(f"Trakt token validation failed: {e}") + return BaseValidationResponse(valid=False, message="Invalid or expired Trakt token") + + +@router.post("/simkl-sync/validation") +async def validate_simkl_sync_token(data: OAuthTokenValidationInput) -> BaseValidationResponse: + """Validate a Simkl OAuth access token.""" + from app.core.config import settings as app_settings + + if not app_settings.SIMKL_CLIENT_ID: + return BaseValidationResponse(valid=False, message="Simkl integration is not configured on this server") + try: + from httpx import AsyncClient + + async with AsyncClient(timeout=10) as client: + resp = await client.get( + "https://api.simkl.com/users/settings", + headers={ + "Authorization": f"Bearer {data.access_token}", + "simkl-api-key": app_settings.SIMKL_CLIENT_ID, + }, + follow_redirects=True, + ) + resp.raise_for_status() + user_info = resp.json() + username = user_info.get("user", {}).get("name") or "Unknown" + return BaseValidationResponse(valid=True, message=f"Connected as {username}") + except Exception as e: + logger.debug(f"Simkl sync token validation failed: {e}") + return BaseValidationResponse(valid=False, message="Invalid or expired Simkl token") diff --git a/app/api/models/tokens.py b/app/api/models/tokens.py new file mode 100644 index 0000000..0d20853 --- /dev/null +++ b/app/api/models/tokens.py @@ -0,0 +1,45 @@ +from typing import 
Literal + +from pydantic import BaseModel, Field + +from app.core.settings import DEFAULT_YEAR_MIN, CatalogConfig, PosterRatingConfig, get_default_year_max + + +class TokenRequest(BaseModel): + authKey: str | None = Field(default=None, description="Stremio auth key") + email: str | None = Field(default=None, description="Stremio account email") + password: str | None = Field(default=None, description="Stremio account password") + catalogs: list[CatalogConfig] | None = Field(default=None, description="Catalog configuration") + language: str = Field(default="en-US", description="Language for TMDB API") + poster_rating: PosterRatingConfig | None = Field(default=None, description="Poster rating provider configuration") + excluded_movie_genres: list[str] = Field(default_factory=list, description="List of movie genre IDs to exclude") + excluded_series_genres: list[str] = Field(default_factory=list, description="List of series genre IDs to exclude") + popularity: Literal["mainstream", "balanced", "gems", "all"] = Field( + default="balanced", description="Popularity for TMDB API" + ) + year_min: int = Field(default=DEFAULT_YEAR_MIN, description="Minimum release year for TMDB API") + year_max: int = Field(default_factory=get_default_year_max, description="Maximum release year for TMDB API") + sorting_order: Literal["default", "movies_first", "series_first"] = Field( + default="default", description="Order of movies and series catalogs" + ) + simkl_api_key: str | None = Field(default=None, description="Simkl API Key for the user") + gemini_api_key: str | None = Field(default=None, description="Gemini API Key for AI features") + tmdb_api_key: str | None = Field(default=None, description="TMDB API Key") + trakt_access_token: str | None = Field(default=None, description="Trakt OAuth access token") + trakt_refresh_token: str | None = Field(default=None, description="Trakt OAuth refresh token") + trakt_token_expires_at: int | None = Field( + default=None, description="Epoch 
seconds when the Trakt access token expires" + ) + simkl_access_token: str | None = Field(default=None, description="Simkl OAuth access token") + watch_history_source: Literal["stremio", "trakt", "simkl"] = Field( + default="stremio", description="Source for watch history" + ) + + +class TokenResponse(BaseModel): + token: str + manifestUrl: str + expiresInSeconds: int | None = Field( + default=None, + description="Number of seconds before the token expires (None means it does not expire)", + ) diff --git a/app/api/router.py b/app/api/router.py index 72c99bc..f30acfc 100644 --- a/app/api/router.py +++ b/app/api/router.py @@ -3,8 +3,9 @@ from .endpoints.announcement import router as announcement_router from .endpoints.catalogs import router as catalogs_router from .endpoints.health import router as health_router +from .endpoints.languages import router as language_router from .endpoints.manifest import router as manifest_router -from .endpoints.meta import router as meta_router +from .endpoints.oauth import router as oauth_router from .endpoints.stats import router as stats_router from .endpoints.tokens import router as tokens_router from .endpoints.validation import router as validation_router @@ -21,7 +22,8 @@ async def root(): api_router.include_router(catalogs_router) api_router.include_router(tokens_router) api_router.include_router(health_router) -api_router.include_router(meta_router) +api_router.include_router(language_router) api_router.include_router(announcement_router) api_router.include_router(stats_router) api_router.include_router(validation_router) +api_router.include_router(oauth_router) diff --git a/app/core/app.py b/app/core/app.py index 4ba9bf7..55638d2 100644 --- a/app/core/app.py +++ b/app/core/app.py @@ -9,9 +9,9 @@ from jinja2 import Environment, FileSystemLoader from loguru import logger -from app.api.endpoints.meta import fetch_languages_list +from app.api.endpoints.languages import fetch_languages_list from app.api.router import api_router 
-from app.core.settings import get_default_catalogs_for_frontend +from app.core.settings import get_current_year, get_default_catalogs_for_frontend, get_default_year_range from app.services.redis_service import redis_service from app.services.tmdb.genre import movie_genres, series_genres from app.services.token_store import token_store @@ -30,10 +30,10 @@ async def lifespan(app: FastAPI): Manage application lifespan events (startup/shutdown). """ # Startup checks - if settings.TOKEN_SALT == "change-me" and settings.APP_ENV == "production": - logger.warning( - "Security Warning: TOKEN_SALT is set to default 'change-me' in production environment! " - "Please set the TOKEN_SALT environment variable." + if settings.APP_ENV == "production" and (not settings.TOKEN_SALT or settings.TOKEN_SALT == "change-me"): + raise RuntimeError( + "TOKEN_SALT is unset or using the insecure default 'change-me' in production. " + "Set the TOKEN_SALT environment variable to a strong, unique value before starting the app." 
) yield @@ -89,6 +89,7 @@ async def configure_page(request: Request, _token: str | None = None): # Format default catalogs for frontend default_catalogs = get_default_catalogs_for_frontend() + year_range_defaults = get_default_year_range() # Format genres for frontend movie_genres_list = [{"id": str(id), "name": name} for id, name in movie_genres.items()] @@ -103,6 +104,8 @@ async def configure_page(request: Request, _token: str | None = None): announcement_html=settings.ANNOUNCEMENT_HTML or "", languages=languages, default_catalogs=default_catalogs, + current_year=get_current_year(), + year_range_defaults=year_range_defaults, movie_genres=movie_genres_list, series_genres=series_genres_list, ) diff --git a/app/core/base_client.py b/app/core/base_client.py index a76362b..3ef6cd4 100644 --- a/app/core/base_client.py +++ b/app/core/base_client.py @@ -1,4 +1,5 @@ import asyncio +import random from typing import Any import httpx @@ -53,7 +54,9 @@ async def _request(self, method: str, url: str, max_tries: int | None = None, ** is_retryable = e.response.status_code in (429, 500, 502, 503, 504) if is_retryable and attempt < tries: - wait_time = 0.5 * (2 ** (attempt - 1)) # Exponential backoff + # Exponential backoff + small random jitter to avoid retry + # stampedes when many concurrent users hit the same 429. + wait_time = 0.5 * (2 ** (attempt - 1)) + random.uniform(0, 0.25) logger.warning( f"Request failed ({method} {url}): {str(e)}. " f"Retrying in {wait_time}s... 
(Attempt {attempt}/{tries})" @@ -69,12 +72,23 @@ async def _request(self, method: str, url: str, max_tries: int | None = None, ** raise httpx.RequestError(f"Request failed for {method} {url} with 0 attempts configured") + @staticmethod + def _safe_json(response: httpx.Response, method: str, url: str) -> dict[str, Any]: + """Parse JSON body, returning {} on empty/non-JSON 2xx responses.""" + if not response.content: + return {} + try: + return response.json() + except ValueError as e: + logger.warning(f"Non-JSON body from {method} {url} (status={response.status_code}): {e}") + return {} + async def get(self, url: str, params: dict[str, Any] | None = None, **kwargs) -> dict[str, Any]: """Perform a GET request and return the JSON response.""" response = await self._request("GET", url, params=params, **kwargs) - return response.json() + return self._safe_json(response, "GET", url) async def post(self, url: str, json: dict[str, Any] | None = None, **kwargs) -> dict[str, Any]: """Perform a POST request and return the JSON response.""" response = await self._request("POST", url, json=json, **kwargs) - return response.json() + return self._safe_json(response, "POST", url) diff --git a/app/core/config.py b/app/core/config.py index cc19ad1..c0a20fc 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -41,6 +41,12 @@ class Settings(BaseSettings): CATALOG_CACHE_TTL: int = 43200 # 12 hours CATALOG_STALE_TTL: int = 604800 # 7 days (soft expiration fallback) + # External history providers (OAuth app credentials) + TRAKT_CLIENT_ID: str | None = None + TRAKT_CLIENT_SECRET: str | None = None + SIMKL_CLIENT_ID: str | None = None + SIMKL_CLIENT_SECRET: str | None = None + # AI DEFAULT_GEMINI_MODEL: str = "gemma-4-26b-a4b-it" GEMINI_API_KEY: str | None = None diff --git a/app/core/constants.py b/app/core/constants.py index 96cbed3..797c282 100644 --- a/app/core/constants.py +++ b/app/core/constants.py @@ -1,5 +1,4 @@ RECOMMENDATIONS_CATALOG_NAME: str = "Top Picks For You" 
-DEFAULT_MIN_ITEMS: int = 8 DEFAULT_CATALOG_LIMIT: int = 20 MAX_CATALOG_ITEMS: int = 100 @@ -16,6 +15,13 @@ WATCHED_SETS_KEY: str = "watchly:watched_sets:{token}:{content_type}" CATALOG_KEY: str = "watchly:catalog:{token}:{type}:{id}" +# Bounded TTL for per-user caches (library items, profile, watched sets, +# library hash, last-build timestamp). Refreshed on every read so an active +# user's data effectively never expires, but a stale install gets cleaned up +# by Redis instead of growing forever. The user's main token key is NOT +# subject to this — that follows TOKEN_TTL_SECONDS. +USER_CACHE_TTL_SECONDS: int = 60 * 60 * 24 * 90 # 90 days + DISCOVER_ONLY_EXTRA: list[dict] = [{"name": "genre", "isRequired": True, "options": ["All"], "optionsLimit": 1}] diff --git a/app/core/settings.py b/app/core/settings.py index 01ae6f0..f7eda56 100644 --- a/app/core/settings.py +++ b/app/core/settings.py @@ -1,3 +1,4 @@ +from datetime import datetime from typing import Literal from pydantic import BaseModel, Field @@ -21,7 +22,25 @@ class PosterRatingConfig(BaseModel): provider: Literal[PosterProvider.RPDB.value, PosterProvider.TOP_POSTERS.value] = Field( description="Provider name: 'rpdb' or 'top_posters'" ) - api_key: str = Field(description="API key for the provider") + api_key: str | None = Field(default=None, description="API key for the provider") + + +def get_current_year() -> int: + return datetime.now().year + + +DEFAULT_YEAR_MIN = 1970 + + +def get_default_year_max() -> int: + return get_current_year() + + +def get_default_year_range() -> dict[str, int]: + return { + "min": DEFAULT_YEAR_MIN, + "max": get_default_year_max(), + } class UserSettings(BaseModel): @@ -30,8 +49,8 @@ class UserSettings(BaseModel): poster_rating: PosterRatingConfig | None = Field(default=None, description="Poster rating provider configuration") excluded_movie_genres: list[str] = Field(default_factory=list) excluded_series_genres: list[str] = Field(default_factory=list) - year_min: int = 
Field(default=1970, description="Minimum release year") - year_max: int = Field(default=2026, description="Maximum release year") + year_min: int = Field(default=DEFAULT_YEAR_MIN, description="Minimum release year") + year_max: int = Field(default_factory=get_default_year_max, description="Maximum release year") popularity: Literal["mainstream", "balanced", "gems", "all"] = Field( default="balanced", description="Popularity preference" ) @@ -41,24 +60,30 @@ class UserSettings(BaseModel): simkl_api_key: str | None = Field(default=None, description="Simkl API Key for the user") gemini_api_key: str | None = Field(default=None, description="Gemini API Key for AI-powered features") tmdb_api_key: str | None = Field(default=None, description="TMDB API Key (used if set; else server config)") + trakt_access_token: str | None = Field(default=None, description="Trakt OAuth access token") + trakt_refresh_token: str | None = Field(default=None, description="Trakt OAuth refresh token") + trakt_token_expires_at: int | None = Field( + default=None, description="Epoch seconds when the Trakt access token expires" + ) + simkl_access_token: str | None = Field(default=None, description="Simkl OAuth access token") + watch_history_source: Literal["stremio", "trakt", "simkl"] = Field( + default="stremio", description="Source for watch history used in profile building" + ) # Catalog descriptions for frontend CATALOG_DESCRIPTIONS = { "watchly.rec": "Personalized recommendations based on your watch history, library and your reactions.", - "watchly.loved": ( - "Recommends items similar to the content you recently loved. example: If you loved 'The Dark Knight'," - " Then it will show similar items to 'The Dark Knight'. This takes your last 3 loved items and shuffles" - " them and picks one at random." - ), - "watchly.watched": ( - "Recommends items similar to the content you recently watched. example: If you watched 'The Dark" - " Knight', Then it will show similar items to 'The Dark Knight'. 
This takes your last 3 watched items" - " and shuffles them and picks one at random." + "watchly.item": ( + "Recommends items similar to one you recently watched or loved. The seed is picked uniformly at random" + " from a pool of your 3 most-recent loved items + your 3 most-recent watched items. The catalog title" + " becomes 'Because you loved <title>' or 'Because you watched <title>' depending on which bucket the" + " seed came from." + ), "watchly.creators": ( - "Recommends items from your top 5 favorite directors and top 5 favorite actors.(Favourite = Most" - " watched items)" + "Recommends items from your recurring directors and lead actors — those who appear across multiple" + " items in your library, not just one. Single-appearance creators are filtered out so the catalog" + " actually reflects who you keep coming back to." ), "watchly.all.loved": "Recommendations based on all your loved items", "watchly.liked.all": "Recommendations based on all your liked items", @@ -84,17 +109,8 @@ def get_default_settings() -> UserSettings: shuffle=False, ), CatalogConfig( - id="watchly.loved", - name="More Like", - enabled=True, - enabled_movie=True, - enabled_series=True, - display_at_home=True, - shuffle=False, - ), - CatalogConfig( - id="watchly.watched", - name="Because you watched", + id="watchly.item", + name="Because you Watched/Loved", enabled=True, enabled_movie=True, enabled_series=True, diff --git a/app/core/version.py b/app/core/version.py index a1078c1..afc83c5 100644 --- a/app/core/version.py +++ b/app/core/version.py @@ -1 +1 @@ -__version__ = "1.9.7" +__version__ = "1.10.0-rc.4" diff --git a/app/models/history.py b/app/models/history.py new file mode 100644 index 0000000..391afad --- /dev/null +++ b/app/models/history.py @@ -0,0 +1,27 @@ +from datetime import datetime +from typing import Literal + +from pydantic import BaseModel, Field + + +class WatchHistoryItem(BaseModel): + """Unified watch history item from any source (Stremio, Trakt, Simkl).""" + + imdb_id: 
str # tt1234567 + type: str # "movie" | "series" + name: str + rating: float | None = None # 1-10 explicit rating (None = unrated) + watch_count: int = 1 + completion: float = 1.0 # 0.0-1.0 (fraction of content watched) + last_watched: datetime | None = None + source: Literal["stremio", "trakt", "simkl"] = "stremio" + + +class WatchHistory(BaseModel): + """Collection of watch history items from a single source.""" + + items: list[WatchHistoryItem] = Field(default_factory=list) + source: Literal["stremio", "trakt", "simkl"] = "stremio" + + def imdb_ids(self) -> set[str]: + return {i.imdb_id for i in self.items} diff --git a/app/models/library.py b/app/models/library.py new file mode 100644 index 0000000..76c2b19 --- /dev/null +++ b/app/models/library.py @@ -0,0 +1,89 @@ +from datetime import datetime + +from pydantic import BaseModel, Field, field_validator + + +class StremioState(BaseModel): + """Represents the user state for a library item.""" + + lastWatched: datetime | None = None + timeWatched: int = 0 + timeOffset: int = 0 + overallTimeWatched: int = 0 + timesWatched: int = 0 + flaggedWatched: int = 0 + duration: int = 0 + video_id: str | None = None + watched: str | None = None + noNotif: bool = False + season: int = 0 + episode: int = 0 + + @field_validator("lastWatched", mode="before") + @classmethod + def parse_last_watched(cls, v): + if isinstance(v, str): + try: + return datetime.fromisoformat(v.replace("Z", "+00:00")) + except ValueError: + return None + return v + + +class StremioLibraryItem(BaseModel): + """Represents a raw item from Stremio library.""" + + id: str = Field(..., alias="_id") + type: str + name: str + state: StremioState = Field(default_factory=StremioState) + mtime: str = Field(default="", alias="_mtime") + poster: str | None = None + temp: bool + removed: bool + + # Enriched fields (not in raw Stremio JSON, added by our service) + is_loved: bool = Field(default=False, alias="_is_loved") + is_liked: bool = Field(default=False, 
alias="_is_liked") + interest_score: float = Field(default=0.0, alias="_interest_score") + + class Config: + populate_by_name = True + + +class LibraryCollection(BaseModel): + """Typed container for categorized library items. + + This is the single shape that flows through the app. When Trakt/Simkl + history providers are added, they produce the same LibraryCollection + so the rest of the app doesn't care about the source. + """ + + loved: list[StremioLibraryItem] = [] + liked: list[StremioLibraryItem] = [] + watched: list[StremioLibraryItem] = [] + added: list[StremioLibraryItem] = [] + removed: list[StremioLibraryItem] = [] + source: str = "stremio" + + def all_items(self) -> list[StremioLibraryItem]: + return self.loved + self.liked + self.watched + self.added + + def all_items_with_removed(self) -> list[StremioLibraryItem]: + return self.loved + self.liked + self.watched + self.added + self.removed + + def for_type(self, content_type: str) -> "LibraryCollection": + return LibraryCollection( + loved=[i for i in self.loved if i.type == content_type], + liked=[i for i in self.liked if i.type == content_type], + watched=[i for i in self.watched if i.type == content_type], + added=[i for i in self.added if i.type == content_type], + removed=[i for i in self.removed if i.type == content_type], + source=self.source, + ) + + def all_imdb_ids(self) -> set[str]: + return {i.id for i in self.all_items_with_removed() if i.id.startswith("tt")} + + def is_empty(self) -> bool: + return not any([self.loved, self.liked, self.watched, self.added]) diff --git a/app/models/taste_profile.py b/app/models/profile.py similarity index 75% rename from app/models/taste_profile.py rename to app/models/profile.py index e144b67..06c8bf0 100644 --- a/app/models/taste_profile.py +++ b/app/models/profile.py @@ -3,10 +3,25 @@ from pydantic import BaseModel, Field +from app.models.library import StremioLibraryItem -class TasteProfile(BaseModel): + +class ScoredItem(BaseModel): + """A processed 
library item with calculated interest scores. + + Output of the ScoringService — used by the profile builder and sampler. """ - Transparent, additive taste profile. + + item: StremioLibraryItem + score: float + completion_rate: float + is_rewatched: bool + is_recent: bool + source_type: str # 'loved' | 'watched' | 'liked' + + +class TasteProfile(BaseModel): + """Transparent, additive taste profile. Answers one question: "Which item is more likely to be liked by this user?" @@ -23,6 +38,13 @@ class TasteProfile(BaseModel): country_scores: dict[str, float] = Field(default_factory=dict, description="Country code → accumulated score") director_scores: dict[int, float] = Field(default_factory=dict, description="Director ID → accumulated score") cast_scores: dict[int, float] = Field(default_factory=dict, description="Actor ID → accumulated score") + # Raw appearance counts kept alongside the score-based dicts above. Scores + # answer "how strong is this creator's signal"; frequencies answer "across + # how many items did the user actually engage with this creator". The + # creators catalog uses the frequency view to filter out single-appearance + # noise that scores alone can't separate. + director_frequency: dict[int, int] = Field(default_factory=dict, description="Director ID → item appearance count") + cast_frequency: dict[int, int] = Field(default_factory=dict, description="Actor ID → item appearance count") runtime_bucket_scores: dict[str, float] = Field( default_factory=dict, description="Runtime bucket (short/medium/long) → accumulated score", @@ -38,11 +60,12 @@ class TasteProfile(BaseModel): default_factory=set, description="Set of processed item IDs to prevent double counting", ) - interest_summary: str | None = Field(default=None, description="LLM-generated description of user interests") + # Which watch-history source produced this profile. 
Used to invalidate + # cached profiles when the user changes watch_history_source in settings, + # so a Stremio-built profile isn't served after switching to Trakt/Simkl. + source: str = Field(default="stremio", description="stremio | trakt | simkl") class Config: - """Pydantic configuration.""" - json_encoders = {datetime: lambda v: v.isoformat()} def get_top_genres(self, limit: int = 5) -> list[tuple[int, float]]: @@ -70,18 +93,12 @@ def get_top_cast(self, limit: int = 5) -> list[tuple[int, float]]: return sorted(self.cast_scores.items(), key=lambda x: x[1], reverse=True)[:limit] def get_top_creators(self, limit: int = 5) -> list[tuple[int, float]]: - """ - Get top N creators (directors + cast merged) by score. - - Runtime merge for convenience. Profile stores them separately. - """ - # Merge directors and cast for combined ranking + """Get top N creators (directors + cast merged) by score.""" all_creators = {**self.director_scores, **self.cast_scores} return sorted(all_creators.items(), key=lambda x: x[1], reverse=True)[:limit] def normalize_for_ranking(self) -> dict[str, dict[Any, float]]: - """ - Normalize scores for ranking (read-time only). + """Normalize scores for ranking (read-time only). Returns normalized scores (0-1 range) for each feature type. Used only when generating recommendations, never during profile updates. 
diff --git a/app/models/scoring.py b/app/models/scoring.py deleted file mode 100644 index 6cc007e..0000000 --- a/app/models/scoring.py +++ /dev/null @@ -1,65 +0,0 @@ -from datetime import datetime - -from pydantic import BaseModel, Field, field_validator - - -class StremioState(BaseModel): - """Represents the user state for a library item.""" - - lastWatched: datetime | None = None - timeWatched: int = 0 - timeOffset: int = 0 - overallTimeWatched: int = 0 - timesWatched: int = 0 - flaggedWatched: int = 0 - duration: int = 0 - video_id: str | None = None - watched: str | None = None - noNotif: bool = False - season: int = 0 - episode: int = 0 - - @field_validator("lastWatched", mode="before") - @classmethod - def parse_last_watched(cls, v): - if isinstance(v, str): - try: - return datetime.fromisoformat(v.replace("Z", "+00:00")) - except ValueError: - return None - return v - - -class StremioLibraryItem(BaseModel): - """Represents a raw item from Stremio library.""" - - id: str = Field(..., alias="_id") - type: str - name: str - state: StremioState = Field(default_factory=StremioState) - mtime: str = Field(default="", alias="_mtime") - poster: str | None = None - temp: bool - removed: bool - - # Enriched fields (not in raw Stremio JSON, added by our service) - is_loved: bool = Field(default=False, alias="_is_loved") - is_liked: bool = Field(default=False, alias="_is_liked") - interest_score: float = Field(default=0.0, alias="_interest_score") - - class Config: - populate_by_name = True - - -class ScoredItem(BaseModel): - """ - A processed item with calculated scores. - This is the output of the ScoringService. 
- """ - - item: StremioLibraryItem - score: float - completion_rate: float - is_rewatched: bool - is_recent: bool - source_type: str # 'loved' | 'watched' | 'liked' diff --git a/app/models/token.py b/app/models/token.py deleted file mode 100644 index 80f6b3a..0000000 --- a/app/models/token.py +++ /dev/null @@ -1,11 +0,0 @@ -from pydantic import BaseModel - - -class UserSettings(BaseModel): - pass - - -class Credentials(BaseModel): - authKey: str - email: str - user_settings: UserSettings diff --git a/app/services/auth.py b/app/services/auth.py new file mode 100644 index 0000000..7db117f --- /dev/null +++ b/app/services/auth.py @@ -0,0 +1,231 @@ +from datetime import datetime, timezone + +from fastapi import HTTPException +from loguru import logger + +from app.api.models.tokens import TokenRequest, TokenResponse +from app.core.config import settings +from app.core.security import redact_token +from app.core.settings import UserSettings, get_default_settings +from app.services.stremio.service import StremioBundle +from app.services.token_store import token_store + + +class AuthService: + async def resolve_auth_key(self, credentials: dict, token: str | None = None) -> str | None: + """Validate auth key. If expired, try email+password login. Update store on refresh.""" + bundle = StremioBundle() + try: + return await self.resolve_auth_key_with_bundle(bundle, credentials, token) + finally: + await bundle.close() + + async def resolve_auth_key_with_bundle( + self, + bundle: StremioBundle, + credentials: dict, + token: str | None = None, + ) -> str | None: + """Validate auth key with an existing Stremio bundle.""" + auth_key = (credentials.get("authKey") or "").strip() or None + email = (credentials.get("email") or "").strip() or None + password = (credentials.get("password") or "").strip() or None + + if auth_key and auth_key.startswith('"') and auth_key.endswith('"'): + auth_key = auth_key[1:-1].strip() + + # 1. 
Try existing auth key + if auth_key: + try: + await bundle.auth.get_user_info(auth_key) + return auth_key + except Exception: + logger.info("Stremio auth key expired or invalid, attempting refresh with credentials") + + # 2. Try login if auth key failed or wasn't provided + if email and password: + try: + new_key = await bundle.auth.login(email, password) + if token and new_key != auth_key: + existing_data = await self.get_credentials(token) + if existing_data: + existing_data["authKey"] = new_key + await token_store.update_user_data(token, existing_data) + return new_key + except Exception as e: + logger.error(f"Stremio login failed: {e}") + return None + + return None + + async def require_auth_key(self, bundle: StremioBundle, credentials: dict, token: str | None = None) -> str: + """Resolve auth key or raise a user-facing error.""" + auth_key = await self.resolve_auth_key_with_bundle(bundle, credentials, token) + if not auth_key: + raise HTTPException(status_code=401, detail="Stremio session expired. Please reconfigure.") + return auth_key + + async def get_credentials(self, token: str) -> dict | None: + """Get user credentials from token store.""" + return await token_store.get_user_data(token) + + async def store_credentials(self, user_id: str, payload: dict) -> str: + """Store credentials, return token.""" + # Ensure last_updated is present if it's a new user + if "last_updated" not in payload: + token = token_store.get_token_from_user_id(user_id) + existing = await self.get_credentials(token) + if existing: + payload["last_updated"] = existing.get("last_updated") + else: + payload["last_updated"] = datetime.now(timezone.utc).isoformat() + + return await token_store.store_user_data(user_id, payload) + + async def get_stremio_user_data(self, payload: TokenRequest) -> tuple[str, str, str]: + """ + Authenticates with Stremio and returns (user_id, email, auth_key). 
+ """ + creds = payload.model_dump() + auth_key = await self.resolve_auth_key(creds) + + if not auth_key: + raise HTTPException( + status_code=400, + detail="Failed to verify Stremio identity. Provide valid credentials.", + ) + + bundle = StremioBundle() + try: + user_info = await bundle.auth.get_user_info(auth_key) + user_id = user_info["user_id"] + resolved_email = user_info.get("email", payload.email or "") + return user_id, resolved_email, auth_key + except Exception as e: + logger.error(f"Stremio identity verification failed: {e}") + raise HTTPException(status_code=400, detail=f"Failed to verify Stremio identity: {e}") + finally: + await bundle.close() + + async def create_user_token(self, payload: TokenRequest) -> tuple[TokenResponse, str, UserSettings]: + """ + Main logic for creating or updating a user token. + + Returns: + Tuple of (TokenResponse, resolved_auth_key, user_settings) so the + caller can trigger caching without re-fetching credentials. + """ + # 1. Authenticate and get user info + user_id, resolved_email, stremio_auth_key = await self.get_stremio_user_data(payload) + + # 2. Check if user already exists + token = token_store.get_token_from_user_id(user_id) + existing_data = await self.get_credentials(token) + + # 3. Prepare payload + user_settings = self._build_user_settings(payload) + payload_to_store = { + "authKey": stremio_auth_key, + "email": resolved_email, + "settings": user_settings.model_dump(), + } + if payload.password: + payload_to_store["password"] = payload.password.strip() + + if existing_data: + payload_to_store["last_updated"] = existing_data.get("last_updated") + + # 4. Store user data + token = await self.store_credentials(user_id, payload_to_store) + + # If watch_history_source changed (or any other setting that affects + # the profile), drop cached profiles so the next catalog request + # rebuilds from the new source instead of serving the stale cache. 
+ if existing_data: + try: + from app.services.user_cache import user_cache as _user_cache + + old_settings = existing_data.get("settings") or {} + old_source = old_settings.get("watch_history_source", "stremio") + if old_source != user_settings.watch_history_source: + for ct in ("movie", "series"): + await _user_cache.invalidate_profile(token, ct) + await _user_cache.invalidate_watched_sets(token, ct) + await _user_cache.invalidate_all_catalogs(token) + logger.info( + f"[{redact_token(token)}] watch_history_source changed " + f"'{old_source}' -> '{user_settings.watch_history_source}'; cleared profile/catalog caches." + ) + except Exception as e: + logger.warning(f"[{redact_token(token)}] Failed to invalidate caches on source change: {e}") + + # 5. Build response + base_url = settings.HOST_NAME + manifest_url = f"{base_url}/{token}/manifest.json" + expires_in = settings.TOKEN_TTL_SECONDS if settings.TOKEN_TTL_SECONDS > 0 else None + + response = TokenResponse( + token=token, + manifestUrl=manifest_url, + expiresInSeconds=expires_in, + ) + return response, stremio_auth_key, user_settings + + def _build_user_settings(self, payload: TokenRequest) -> UserSettings: + default_settings = get_default_settings() + return UserSettings( + language=payload.language or default_settings.language, + catalogs=payload.catalogs if payload.catalogs else default_settings.catalogs, + poster_rating=payload.poster_rating, + excluded_movie_genres=payload.excluded_movie_genres, + excluded_series_genres=payload.excluded_series_genres, + year_min=payload.year_min, + year_max=payload.year_max, + popularity=payload.popularity, + sorting_order=payload.sorting_order, + simkl_api_key=payload.simkl_api_key, + gemini_api_key=payload.gemini_api_key, + tmdb_api_key=payload.tmdb_api_key, + trakt_access_token=payload.trakt_access_token, + trakt_refresh_token=payload.trakt_refresh_token, + trakt_token_expires_at=payload.trakt_token_expires_at, + simkl_access_token=payload.simkl_access_token, + 
watch_history_source=payload.watch_history_source, + ) + + async def get_identity_with_settings(self, payload: TokenRequest) -> dict: + """Fetch Stremio identity and associated user settings if they exist.""" + user_id, email, _ = await self.get_stremio_user_data(payload) + + token = token_store.get_token_from_user_id(user_id) + existing_data = await self.get_credentials(token) + exists = bool(existing_data) + + response = {"user_id": user_id, "email": email, "exists": exists} + + if exists and existing_data: + # Reconstruct UserSettings to ensure defaults are included for old accounts + raw_settings = existing_data.get("settings", {}) + try: + user_settings = UserSettings(**raw_settings) + response["settings"] = user_settings.model_dump() + except Exception as e: + logger.warning(f"Failed to normalize settings for user {user_id}: {e}") + response["settings"] = raw_settings + + return response + + async def delete_user_account(self, payload: TokenRequest) -> None: + """Deletes user account and associated data.""" + user_id, _, _ = await self.get_stremio_user_data(payload) + token = token_store.get_token_from_user_id(user_id) + + existing_data = await self.get_credentials(token) + if not existing_data: + raise HTTPException(status_code=404, detail="Account not found.") + + await token_store.delete_token(token) + logger.info(f"[{redact_token(token)}] Token deleted for user {user_id}") + + +auth_service = AuthService() diff --git a/app/services/catalog.py b/app/services/catalog.py deleted file mode 100644 index 73cb55f..0000000 --- a/app/services/catalog.py +++ /dev/null @@ -1,343 +0,0 @@ -import asyncio -import random -from datetime import datetime, timezone -from typing import Any - -from loguru import logger - -from app.core.constants import DISCOVER_ONLY_EXTRA -from app.core.settings import CatalogConfig, UserSettings -from app.services.interest_summary import interest_summary_service -from app.services.profile.integration import ProfileIntegration -from 
app.services.row_generator import RowGeneratorService -from app.services.scoring import ScoringService -from app.services.tmdb.service import get_tmdb_service -from app.services.user_cache import user_cache -from app.utils.catalog import get_catalogs_from_config - - -class DynamicCatalogService: - """ - Generates dynamic catalog rows based on user library and preferences. - """ - - def __init__(self, language: str = "en-US", tmdb_api_key: str | None = None): - self.tmdb_service = get_tmdb_service(language=language, api_key=tmdb_api_key) - self.scoring_service = ScoringService() - self.profile_integration = ProfileIntegration(language=language, tmdb_api_key=tmdb_api_key) - self.row_generator = RowGeneratorService(tmdb_service=self.tmdb_service) - self.PROFILE_MAX_ITEMS = 50 - - @staticmethod - def normalize_type(type_): - return "series" if type_ == "tv" else type_ - - def build_catalog_entry(self, item, label, config_id, display_at_home: bool = True): - item_id = item.get("_id", "") - # Use watchly.{config_id}.{item_id} format for better organization - if config_id in ["watchly.item", "watchly.loved", "watchly.watched"]: - # New Item-based catalog format - catalog_id = f"{config_id}.{item_id}" - elif item_id.startswith("tt") and config_id in ["watchly.loved", "watchly.watched"]: - catalog_id = f"{config_id}.{item_id}" - else: - catalog_id = item_id - - title = item.get("name") or "" - - extra = DISCOVER_ONLY_EXTRA if not display_at_home else [] - - return { - "type": self.normalize_type(item.get("type")), - "id": catalog_id, - "name": f"{label} {title}".strip(), - # Translate only the label; keep the work title unchanged (see translation.apply_catalog_translation). - "_catalog_name_prefix": label, - "_catalog_name_suffix": title, - "extra": extra, - } - - def _get_smart_scored_items(self, library_items: dict, content_type: str, max_items: int = 50) -> list: - """ - Get smart sampled items for profile building. 
- Always includes all loved/liked/added items, then top watched items by interest_score. - - Args: - library_items: Library items dict - content_type: Type of content (movie/series) - max_items: Maximum items to return (default: 50) - - Returns: - List of ScoredItem objects - """ - all_items = ( - library_items.get("loved", []) - + library_items.get("liked", []) - + library_items.get("watched", []) - + library_items.get("added", []) - ) - typed_items = [it for it in all_items if it.get("type") == content_type] - - if not typed_items: - return [] - - # Get added items (strong signal - user wants to watch these) - added_item_ids = {it.get("_id") for it in library_items.get("added", [])} - added_items = [it for it in typed_items if it.get("_id") in added_item_ids] - - # Separate loved/liked from watched items (excluding added) - loved_liked_items = [ - it - for it in typed_items - if (it.get("_is_loved") or it.get("_is_liked")) and it.get("_id") not in added_item_ids - ] - watched_items = [ - it - for it in typed_items - if not (it.get("_is_loved") or it.get("_is_liked") or it.get("_id") in added_item_ids) - ] - - # Always include all loved/liked/added items (score them) - # These are strong signals of user intent - strong_signal_items = loved_liked_items + added_items - strong_signal_scored = [self.scoring_service.process_item(it) for it in strong_signal_items] - - # For watched items, score them and sort by interest_score - watched_scored = [self.scoring_service.process_item(it) for it in watched_items] - watched_scored.sort(key=lambda x: x.score, reverse=True) - - # Combine: all loved/liked/added + top watched items by score - # Limit total to max_items - remaining_slots = max(0, max_items - len(strong_signal_scored)) - top_watched = watched_scored[:remaining_slots] - - return strong_signal_scored + top_watched - - async def get_theme_based_catalogs( - self, - library_items: dict, - user_settings: UserSettings | None = None, - enabled_movie: bool = True, - 
enabled_series: bool = True, - display_at_home: bool = True, - token: str | None = None, - ) -> list[dict]: - """Build thematic catalogs by profiling items using smart sampling.""" - # 1. Prepare Scored History using smart sampling (loved/liked + top watched by score) - # We'll get items per content type in the generation function - - # 2. Extract Genre Filters - excluded_movie_genres = [] - excluded_series_genres = [] - gemini_api_key = None - if user_settings: - excluded_movie_genres = [int(g) for g in user_settings.excluded_movie_genres] - excluded_series_genres = [int(g) for g in user_settings.excluded_series_genres] - gemini_api_key = user_settings.gemini_api_key - - logger.info( - f"[Theme Catalogs] gemini_api_key={'SET' if gemini_api_key else 'NONE'}," - f" token={'SET' if token else 'NONE'}" - ) - - # 3. Generate Rows - async def _generate_for_type(media_type: str, genres: list[int]): - logger.info(f"[Theme Catalogs] _generate_for_type called for {media_type}") - - # Build profile using new system - profile, _, _ = await self.profile_integration.build_profile_from_library( - library_items, media_type, None, None - ) - if not profile: - logger.warning(f"Failed to build profile for {media_type}") - return media_type, [] - - # Generate interest summary if API key is present. 
- if gemini_api_key and token: - try: - logger.info(f"Generating interest summary for {media_type}...") - summary = await interest_summary_service.generate_summary(profile, gemini_api_key) - if summary: - profile.interest_summary = summary - logger.info(f"Interest summary generated for {media_type}: {summary[:80]}...") - else: - logger.warning(f"Interest summary generation returned empty for {media_type}") - except Exception as e: - logger.warning(f"Failed to generate interest summary for {media_type}: {e}") - else: - logger.info( - f"[Theme Catalogs] Skipping summary: gemini_api_key={'SET' if gemini_api_key else 'NONE'}," - f" token={'SET' if token else 'NONE'}" - ) - - # Always save the updated profile (with or without summary) - if token: - try: - await user_cache.set_profile(token, media_type, profile) - logger.info(f"Saved profile for {media_type} (has_summary={profile.interest_summary is not None})") - except Exception as e: - logger.warning(f"Failed to save profile for {media_type}: {e}") - - try: - catalogs = await self.row_generator.generate_rows(profile, media_type, api_key=gemini_api_key) - return media_type, catalogs - except Exception as e: - logger.error(f"Failed to generate thematic rows for {media_type}: {e}") - raise e - - tasks = [] - if enabled_movie: - tasks.append(_generate_for_type("movie", excluded_movie_genres)) - if enabled_series: - tasks.append(_generate_for_type("series", excluded_series_genres)) - - results = await asyncio.gather(*tasks, return_exceptions=True) - - # 4. 
Assembly with error handling - catalogs = [] - - extra = DISCOVER_ONLY_EXTRA if not display_at_home else [] - - for result in results: - if isinstance(result, Exception): - continue - media_type, rows = result - for row in rows: - catalogs.append({"type": media_type, "id": row.id, "name": row.title, "extra": extra}) - - return catalogs - - async def get_dynamic_catalogs( - self, library_items: dict, user_settings: UserSettings | None = None, token: str | None = None - ) -> list[dict]: - """Generate all dynamic catalog rows based on enabled configurations.""" - catalogs = [] - if not user_settings: - return catalogs - - # 1. Resolve Configs - theme_cfg, loved_cfg, watched_cfg = self._resolve_catalog_configs(user_settings) - - # 2. Add Thematic Catalogs - if theme_cfg and theme_cfg.enabled: - # Filter theme catalogs by enabled_movie/enabled_series - enabled_movie = getattr(theme_cfg, "enabled_movie", True) - enabled_series = getattr(theme_cfg, "enabled_series", True) - display_at_home = getattr(theme_cfg, "display_at_home", True) - theme_catalogs = await self.get_theme_based_catalogs( - library_items, user_settings, enabled_movie, enabled_series, display_at_home, token - ) - catalogs.extend(theme_catalogs) - - # 3. Add Item-Based Catalogs (Movies & Series) - for mtype in ["movie", "series"]: - await self._add_item_based_rows(catalogs, library_items, mtype, loved_cfg, watched_cfg) - - # 4. Add watchly.rec catalog - catalogs.extend(get_catalogs_from_config(user_settings, "watchly.rec", "Top Picks for You", True, True)) - - # 5. Add watchly.creators catalog - catalogs.extend( - get_catalogs_from_config(user_settings, "watchly.creators", "From your favourite Creators", False, False) - ) - - # 6. Add watchly.all.loved catalog - catalogs.extend( - get_catalogs_from_config(user_settings, "watchly.all.loved", "Based on what you loved", True, True) - ) - - # 7. 
Add watchly.liked.all catalog - catalogs.extend( - get_catalogs_from_config(user_settings, "watchly.liked.all", "Based on what you liked", True, True) - ) - - return catalogs - - def _resolve_catalog_configs(self, user_settings: UserSettings) -> tuple[Any, Any, Any]: - """Extract and fallback catalog configurations from user settings.""" - cfg_map = {c.id: c for c in user_settings.catalogs} - - theme = cfg_map.get("watchly.theme") - loved = cfg_map.get("watchly.loved") - watched = cfg_map.get("watchly.watched") - - # Fallback for old settings format (watchly.item) - if not loved and not watched: - old_item = cfg_map.get("watchly.item") - if old_item and old_item.enabled: - loved = CatalogConfig(id="watchly.loved", name=None, enabled=True) - watched = CatalogConfig(id="watchly.watched", name=None, enabled=True) - - return theme, loved, watched - - def _parse_item_last_watched(self, item: dict) -> datetime: - """Helper to extract and parse the most relevant activity date for an item.""" - val = item.get("state", {}).get("lastWatched") - if val: - try: - if isinstance(val, str): - return datetime.fromisoformat(val.replace("Z", "+00:00")) - return val - except (ValueError, TypeError): - pass - - # Fallback to mtime - val = item.get("_mtime") - if val: - try: - return datetime.fromisoformat(str(val).replace("Z", "+00:00")) - except (ValueError, TypeError): - pass - return datetime.min.replace(tzinfo=timezone.utc) - - async def _add_item_based_rows( - self, - catalogs: list, - library_items: dict, - content_type: str, - loved_config, - watched_config, - ): - # Check if this content type is enabled for the configs - def is_type_enabled(config, content_type: str) -> bool: - if not config: - return False - if content_type == "movie": - return getattr(config, "enabled_movie", True) - elif content_type == "series": - return getattr(config, "enabled_series", True) - return True - - # 1. 
More Like <Loved Item> - last_loved = None # Initialize for the watched check - if loved_config and loved_config.enabled and is_type_enabled(loved_config, content_type): - loved = [i for i in library_items.get("loved", []) if i.get("type") == content_type] - loved.sort(key=self._parse_item_last_watched, reverse=True) - - # gather random last loved from last 3 items - last_loved = random.choice(loved[:3]) if loved else None - if last_loved: - label = loved_config.name if loved_config.name else "More like" - loved_config_display_at_home = getattr(loved_config, "display_at_home", True) - catalogs.append( - self.build_catalog_entry(last_loved, label, "watchly.loved", loved_config_display_at_home) - ) - - # 2. Because you watched <Watched Item> - if watched_config and watched_config.enabled and is_type_enabled(watched_config, content_type): - watched = [i for i in library_items.get("watched", []) if i.get("type") == content_type] - watched.sort(key=self._parse_item_last_watched, reverse=True) - - # watched cannot be similar to loved - if last_loved: - watched = [i for i in watched if i.get("_id") != last_loved.get("_id")] - - # gather random last watched from last 3 items - last_watched = random.choice(watched[:3]) if watched else None - - if last_watched: - label = watched_config.name if watched_config.name else "Because you watched" - watched_config_display_at_home = getattr(watched_config, "display_at_home", True) - catalogs.append( - self.build_catalog_entry(last_watched, label, "watchly.watched", watched_config_display_at_home) - ) diff --git a/app/services/catalog_definitions.py b/app/services/catalog_definitions.py new file mode 100644 index 0000000..8756b30 --- /dev/null +++ b/app/services/catalog_definitions.py @@ -0,0 +1,376 @@ +import asyncio +import random +from datetime import datetime, timezone +from typing import Any, cast + +from loguru import logger + +from app.core.constants import DISCOVER_ONLY_EXTRA +from app.core.settings import CatalogConfig, 
UserSettings +from app.models.library import LibraryCollection +from app.services.profile.service import ProfileService +from app.services.row_generator import RowGeneratorService +from app.services.tmdb.service import get_tmdb_service +from app.services.user_cache import user_cache + + +def get_catalogs_from_config( + user_settings: UserSettings, + cat_id: str, + default_name: str, + default_movie: bool, + default_series: bool, +) -> list[dict[str, Any]]: + catalogs = [] + config = next((c for c in user_settings.catalogs if c.id == cat_id), None) + + if config and config.enabled: + name = config.name if config.name else default_name + enabled_movie = getattr(config, "enabled_movie", default_movie) + enabled_series = getattr(config, "enabled_series", default_series) + display_at_home = getattr(config, "display_at_home", True) + extra = DISCOVER_ONLY_EXTRA if not display_at_home else [] + + if enabled_movie: + catalogs.append({"type": "movie", "id": cat_id, "name": name, "extra": extra}) + if enabled_series: + catalogs.append({"type": "series", "id": cat_id, "name": name, "extra": extra}) + + return catalogs + + +def get_config_id(catalog: dict[str, Any]) -> str | None: + catalog_id = catalog.get("id", "") + if catalog_id.startswith("watchly.theme."): + return "watchly.theme" + if catalog_id.startswith("watchly.item."): + return "watchly.item" + # Legacy stored manifests still emit watchly.loved.* / watchly.watched.* — + # map them to the unified watchly.item config so user ordering keeps working. 
+ if catalog_id.startswith("watchly.loved.") or catalog_id.startswith("watchly.watched."): + return "watchly.item" + return catalog_id + + +def sort_catalogs(catalogs: list[dict[str, Any]], user_settings: UserSettings) -> list[dict[str, Any]]: + """Sort catalogs according to user settings and content-type order.""" + if not user_settings: + return catalogs + + order_map = {c.id: i for i, c in enumerate(user_settings.catalogs)} + + def get_setting_index(catalog: dict[str, Any]) -> int: + config_id = get_config_id(catalog) + if config_id is None: + return 999 + return order_map.get(config_id, 999) + + sorting_order = getattr(user_settings, "sorting_order", "default") + + if sorting_order == "movies_first": + return sorted( + catalogs, + key=lambda x: ( + 0 if x.get("type") == "movie" else 1, + get_setting_index(x), + ), + ) + + if sorting_order == "series_first": + return sorted( + catalogs, + key=lambda x: ( + 0 if x.get("type") == "series" else 1, + get_setting_index(x), + ), + ) + + return sorted(catalogs, key=get_setting_index) + + +class DynamicCatalogService: + """Generates catalog definitions from user history and settings.""" + + def __init__(self, language: str = "en-US", tmdb_api_key: str | None = None): + self.language = language + self.tmdb_api_key = tmdb_api_key + tmdb_service = get_tmdb_service(language=language, api_key=tmdb_api_key) + self.profile_service = ProfileService(language=language, tmdb_api_key=tmdb_api_key) + self.row_generator = RowGeneratorService(tmdb_service=tmdb_service) + + @staticmethod + def normalize_type(type_: str) -> str: + return "series" if type_ == "tv" else type_ + + def build_catalog_entry( + self, + item, + label: str, + config_id: str, + display_at_home: bool = True, + ) -> dict[str, Any]: + from app.models.library import StremioLibraryItem + + # Support both typed items and raw dicts + if isinstance(item, StremioLibraryItem): + item_id = item.id + item_type = item.type + item_name = item.name + else: + item_id = 
item.get("_id", "") + item_type = item.get("type", "") + item_name = item.get("name", "") + + if config_id == "watchly.item": + catalog_id = f"{config_id}.{item_id}" + else: + catalog_id = item_id + + extra = DISCOVER_ONLY_EXTRA if not display_at_home else [] + return { + "type": self.normalize_type(item_type), + "id": catalog_id, + "name": f"{label} {item_name}", + "_catalog_name_prefix": label, + "_catalog_name_suffix": item_name, + "extra": extra, + } + + async def get_dynamic_catalogs( + self, + library_items: LibraryCollection, + user_settings: UserSettings | None = None, + token: str | None = None, + ) -> list[dict[str, Any]]: + """Generate all dynamic catalog rows based on enabled configurations.""" + catalogs: list[dict[str, Any]] = [] + if not user_settings: + return catalogs + + theme_cfg, item_cfg = self._resolve_catalog_configs(user_settings) + + if theme_cfg and theme_cfg.enabled: + enabled_movie = getattr(theme_cfg, "enabled_movie", True) + enabled_series = getattr(theme_cfg, "enabled_series", True) + display_at_home = getattr(theme_cfg, "display_at_home", True) + theme_catalogs = await self._build_theme_catalogs( + library_items, + user_settings, + enabled_movie, + enabled_series, + display_at_home, + token, + ) + catalogs.extend(theme_catalogs) + + for mtype in ["movie", "series"]: + await self._add_item_based_rows(catalogs, library_items, mtype, item_cfg) + + catalogs.extend(get_catalogs_from_config(user_settings, "watchly.rec", "Top Picks for You", True, True)) + catalogs.extend( + get_catalogs_from_config( + user_settings, + "watchly.creators", + "From your favourite Creators", + False, + False, + ) + ) + catalogs.extend( + get_catalogs_from_config( + user_settings, + "watchly.all.loved", + "Based on what you loved", + True, + True, + ) + ) + catalogs.extend( + get_catalogs_from_config( + user_settings, + "watchly.liked.all", + "Based on what you liked", + True, + True, + ) + ) + + return catalogs + + # --- Theme catalog building (was 
ThemeCatalogService) --- + + async def _build_theme_catalogs( + self, + library_items: LibraryCollection, + user_settings: UserSettings | None, + enabled_movie: bool, + enabled_series: bool, + display_at_home: bool, + token: str | None, + ) -> list[dict[str, Any]]: + gemini_api_key = user_settings.gemini_api_key if user_settings else None + + tasks = [] + if enabled_movie: + tasks.append(self._build_theme_rows_for_type(library_items, "movie", gemini_api_key, token, user_settings)) + if enabled_series: + tasks.append(self._build_theme_rows_for_type(library_items, "series", gemini_api_key, token, user_settings)) + + results = await asyncio.gather(*tasks, return_exceptions=True) + catalogs: list[dict[str, Any]] = [] + extra = DISCOVER_ONLY_EXTRA if not display_at_home else [] + + for result in results: + if not isinstance(result, tuple): + continue + media_type, rows = cast(tuple[str, list[Any]], result) + for row in rows: + catalogs.append( + { + "type": media_type, + "id": row.id, + "name": row.title, + "extra": extra, + } + ) + + return catalogs + + async def _build_theme_rows_for_type( + self, + library_items: LibraryCollection, + media_type: str, + gemini_api_key: str | None, + token: str | None, + user_settings: UserSettings | None = None, + ) -> tuple[str, list[Any]]: + logger.info(f"[Theme Catalogs] Building rows for {media_type}") + + # Try cached profile first, build fresh if missing (honors watch_history_source). 
+ profile = None + if token: + profile = await user_cache.get_profile(token, media_type) + + if not profile: + if token: + profile, _, _ = await self.profile_service.build_and_cache_profile( + token, media_type, library_items, None, None, user_settings=user_settings + ) + else: + profile, _, _ = await self.profile_service.build_profile_from_library( + library_items, media_type, None, None + ) + + if not profile: + logger.warning(f"Failed to build profile for {media_type}") + return media_type, [] + + rows = await self.row_generator.generate_rows(profile, media_type, api_key=gemini_api_key) + return media_type, rows + + # --- Item-based rows --- + + def _resolve_catalog_configs(self, user_settings: UserSettings) -> tuple[Any, Any]: + cfg_map = {c.id: c for c in user_settings.catalogs} + theme = cfg_map.get("watchly.theme") + item = cfg_map.get("watchly.item") + + # Legacy migration: users created before the loved/watched merge still + # have separate `watchly.loved` and `watchly.watched` entries in their + # saved settings. Synthesize a watchly.item config from whichever is + # present so they don't lose the catalog on first load after the + # upgrade. Donor preference: loved over watched (loved was opt-in + # branded as the more intentional signal). 
+ if not item: + legacy_loved = cfg_map.get("watchly.loved") + legacy_watched = cfg_map.get("watchly.watched") + donor = legacy_loved or legacy_watched + if donor: + enabled = bool((legacy_loved and legacy_loved.enabled) or (legacy_watched and legacy_watched.enabled)) + item = CatalogConfig( + id="watchly.item", + name="Because you Watched/Loved", + enabled=enabled, + enabled_movie=getattr(donor, "enabled_movie", True), + enabled_series=getattr(donor, "enabled_series", True), + display_at_home=getattr(donor, "display_at_home", True), + shuffle=getattr(donor, "shuffle", False), + ) + + return theme, item + + def _parse_item_last_watched(self, item) -> datetime: + from app.models.library import StremioLibraryItem + + if isinstance(item, StremioLibraryItem): + if item.state.lastWatched: + return item.state.lastWatched + if item.mtime: + try: + return datetime.fromisoformat(str(item.mtime).replace("Z", "+00:00")) + except (ValueError, TypeError): + pass + return datetime.min.replace(tzinfo=timezone.utc) + + # Fallback for raw dicts + val = item.get("state", {}).get("lastWatched") + if val: + try: + if isinstance(val, str): + return datetime.fromisoformat(val.replace("Z", "+00:00")) + return val + except (ValueError, TypeError): + pass + + val = item.get("_mtime") + if val: + try: + return datetime.fromisoformat(str(val).replace("Z", "+00:00")) + except (ValueError, TypeError): + pass + return datetime.min.replace(tzinfo=timezone.utc) + + async def _add_item_based_rows( + self, + catalogs: list[dict[str, Any]], + library_items: LibraryCollection, + content_type: str, + item_config: Any, + ) -> None: + """Emit one item-based row per content type. + + Seed selection: take the 3 most-recent loved items and the 3 most-recent + watched items, combine, and pick uniformly at random. The label + ("Because you loved X" vs "Because you watched X") follows the bucket + the chosen seed came from. 
The configured `name` is for the FE display + only — the served catalog title is always one of the two dynamic + labels — so the configure page disables renaming this catalog. + """ + if not item_config or not item_config.enabled: + return + + if content_type == "movie" and not getattr(item_config, "enabled_movie", True): + return + if content_type == "series" and not getattr(item_config, "enabled_series", True): + return + + loved = [i for i in library_items.loved if i.type == content_type] + watched = [i for i in library_items.watched if i.type == content_type] + loved.sort(key=self._parse_item_last_watched, reverse=True) + watched.sort(key=self._parse_item_last_watched, reverse=True) + + loved_pool = loved[:3] + watched_pool = watched[:3] + # Tag each candidate with its origin bucket so the label can follow + # the actual pick rather than re-checking flags after the fact. + candidates: list[tuple[Any, bool]] = [(i, True) for i in loved_pool] + candidates += [(i, False) for i in watched_pool] + + if not candidates: + return + + seed, seed_is_loved = random.choice(candidates) + label = "Because you loved" if seed_is_loved else "Because you watched" + + display_at_home = getattr(item_config, "display_at_home", True) + catalogs.append(self.build_catalog_entry(seed, label, "watchly.item", display_at_home)) diff --git a/app/services/catalog_updater.py b/app/services/catalog_updater.py index 15dd20f..bc8c62c 100644 --- a/app/services/catalog_updater.py +++ b/app/services/catalog_updater.py @@ -7,24 +7,32 @@ from app.core.config import settings from app.core.security import redact_token -from app.core.settings import UserSettings -from app.services.catalog import DynamicCatalogService -from app.services.manifest import manifest_service +from app.services.auth import auth_service from app.services.stremio.service import StremioBundle from app.services.token_store import token_store -from app.services.translation import apply_catalog_translation -from app.utils.catalog 
import sort_catalogs class CatalogUpdater: """ - Catalog updater that triggers updates on-demand when users request catalogs. + Triggers on-demand catalog updates by building a fresh manifest + and pushing the catalogs to Stremio's addon collection. Uses in-memory locking to prevent duplicate concurrent updates. """ def __init__(self): - # In-memory lock to prevent duplicate updates for the same token self._updating_tokens: set[str] = set() + # Retain background task handles so they don't get GC'd mid-flight, + # and so unhandled exceptions surface in logs. + self._pending_tasks: set[asyncio.Task] = set() + + def _on_task_done(self, task: asyncio.Task) -> None: + self._pending_tasks.discard(task) + try: + exc = task.exception() + except asyncio.CancelledError: + return + if exc is not None: + logger.error(f"Background catalog update task crashed: {exc!r}") def _needs_update(self, credentials: dict[str, Any]) -> bool: """Check if catalog update is needed based on last_updated timestamp.""" @@ -33,23 +41,19 @@ def _needs_update(self, credentials: dict[str, Any]) -> bool: last_updated = credentials.get("last_updated") if not last_updated: - # No timestamp means never updated, needs update return True try: - # Parse ISO format timestamp if isinstance(last_updated, str): last_update_time = datetime.fromisoformat(last_updated.replace("Z", "+00:00")) else: last_update_time = last_updated - # Check if more than 11 hours have passed (update if less than 1 hour remaining) now = datetime.now(timezone.utc) if last_update_time.tzinfo is None: last_update_time = last_update_time.replace(tzinfo=timezone.utc) time_since_update = (now - last_update_time).total_seconds() - # Update if less than 1 hour remaining until next update return time_since_update >= (settings.CATALOG_REFRESH_INTERVAL_SECONDS - 3600) except (ValueError, TypeError, AttributeError) as e: logger.warning(f"Failed to parse last_updated timestamp: {e}. 
Treating as needs update.") @@ -58,135 +62,81 @@ def _needs_update(self, credentials: dict[str, Any]) -> bool: async def refresh_catalogs_for_credentials( self, token: str, credentials: dict[str, Any], update_timestamp: bool = True ) -> bool: - """ - Refresh catalogs for a user's credentials. - - Args: - token: User token - credentials: User credentials dict - update_timestamp: Whether to update last_updated timestamp on success - - Returns: - True if update was successful, False otherwise - """ + """Build a fresh manifest and push the catalogs to Stremio.""" if not credentials: logger.warning(f"[{redact_token(token)}] Attempted to refresh catalogs with no credentials.") - raise HTTPException(status_code=401, detail="Invalid or expired token. Please reconfigure the addon.") + raise HTTPException( + status_code=401, + detail="Invalid or expired token. Please reconfigure the addon.", + ) - auth_key = credentials.get("authKey") - # check if auth key is valid bundle = StremioBundle() try: - try: - await bundle.auth.get_user_info(auth_key) - except Exception as e: - logger.exception(f"[{redact_token(token)}] Invalid auth key. Falling back to login: {e}") - email = credentials.get("email") - password = credentials.get("password") - if email and password: - auth_key = await bundle.auth.login(email, password) - credentials["authKey"] = auth_key - await token_store.update_user_data(token, credentials) - else: - return True # true since we won't be able to update it again. so no need to try again. + auth_key = await auth_service.resolve_auth_key_with_bundle(bundle, credentials, token) + if not auth_key: + return True - # 1. Check if addon is still installed + # Check if addon is still installed try: - addon_installed = await bundle.addons.is_addon_installed(auth_key) - if not addon_installed: - logger.info(f"[{redact_token(token)}] User has not installed addon. 
Removing token from redis") + if not await bundle.addons.is_addon_installed(auth_key): + logger.info(f"[{redact_token(token)}] Addon not installed, skipping update") return True except Exception as e: - logger.exception(f"[{redact_token(token)}] Failed to check if addon is installed: {e}") + logger.exception(f"[{redact_token(token)}] Failed to check addon install status: {e}") return False - # 2. Extract settings and refresh - user_settings = None - if credentials.get("settings"): - try: - user_settings = UserSettings(**credentials["settings"]) - except Exception as e: - logger.exception(f"[{redact_token(token)}] Failed to parse user settings: {e}") - # if user doesn't have setting, we can't update the catalogs. - # so no need to try again. - return True - - library_items = await manifest_service.cache_library_and_profiles(bundle, auth_key, user_settings, token) - language = user_settings.language if user_settings else "en-US" - - from app.core.settings import resolve_tmdb_api_key - - tmdb_key = resolve_tmdb_api_key(user_settings) - dynamic_catalog_service = DynamicCatalogService( - language=language, - tmdb_api_key=tmdb_key, - ) - - catalogs = await dynamic_catalog_service.get_dynamic_catalogs( - library_items=library_items, user_settings=user_settings, token=token - ) + # Reuse ManifestService to build catalogs + # (handles library caching, profile building, catalog definitions, + # translation, and sorting — no need to reimplement here) + from app.services.manifest import manifest_service - lang = user_settings.language if user_settings else None - for cat in catalogs: - await apply_catalog_translation(cat, lang) - - # sort catalogs by order in user settings - if user_settings: - catalogs = sort_catalogs(catalogs, user_settings) + manifest = await manifest_service.get_manifest_for_token(token) + catalogs = manifest.get("catalogs", []) success = await bundle.addons.update_catalogs(auth_key, catalogs) - # Update timestamp and invalidate cache only on success if 
success and update_timestamp: try: - # Update last_updated timestamp to current time - # This represents when the update completed successfully now = datetime.now(timezone.utc) - last_updated_str = now.replace(microsecond=0).isoformat() - credentials["last_updated"] = last_updated_str + credentials["last_updated"] = now.replace(microsecond=0).isoformat() await token_store.update_user_data(token, credentials) - logger.debug(f"[{redact_token(token)}] Updated last_updated timestamp to {last_updated_str}") + logger.debug(f"[{redact_token(token)}] Updated last_updated timestamp") except Exception as e: - logger.warning(f"[{redact_token(token)}] Failed to update last_updated timestamp: {e}") + logger.warning(f"[{redact_token(token)}] Failed to update timestamp: {e}") return success except Exception as e: logger.exception(f"[{redact_token(token)}] Failed to update catalogs in background: {e}") try: - error_msg = f"Failed to update catalogs: {str(e)}" - description = ( - f"Movie and series recommendations based on your Stremio library.\n\n⚠️ Status: Error\n{error_msg}" - ) - await bundle.addons.update_description(auth_key, description) + error_auth_key = credentials.get("authKey") + if isinstance(error_auth_key, str) and error_auth_key: + description = ( + "Movie and series recommendations based on your Stremio library.\n\n" + f"⚠️ Status: Error\nFailed to update catalogs: {e}" + ) + await bundle.addons.update_description(error_auth_key, description) except Exception as update_err: - logger.warning(f"[{redact_token(token)}] Failed to update addon description with error: {update_err}") + logger.warning(f"[{redact_token(token)}] Failed to update addon description: {update_err}") return False finally: await bundle.close() async def trigger_update(self, token: str, credentials: dict[str, Any]) -> None: - """ - Trigger a catalog update if needed. - This function checks if update is needed and fires a background task. - Uses in-memory lock to prevent duplicate updates. 
- """ - # Check if already updating + """Fire a background catalog update if needed. In-memory lock prevents duplicates.""" if token in self._updating_tokens: logger.debug(f"[{redact_token(token)}] Update already in progress, skipping") return - # Check if update is needed if not self._needs_update(credentials): logger.debug(f"[{redact_token(token)}] Catalog update not needed yet") return - # Add to lock and fire background update self._updating_tokens.add(token) logger.info(f"[{redact_token(token)}] Triggering catalog update") - - # Fire and forget background task - asyncio.create_task(self._update_task(token, credentials)) + task = asyncio.create_task(self._update_task(token, credentials)) + self._pending_tasks.add(task) + task.add_done_callback(self._on_task_done) async def _update_task(self, token: str, credentials: dict[str, Any]) -> None: """Background task that performs the actual catalog update.""" @@ -199,7 +149,6 @@ async def _update_task(self, token: str, credentials: dict[str, Any]) -> None: except Exception as e: logger.exception(f"[{redact_token(token)}] Catalog update task failed: {e}") finally: - # Always remove from lock self._updating_tokens.discard(token) diff --git a/app/services/cinemeta_service.py b/app/services/cinemeta_service.py index c86d1c7..7a52c76 100644 --- a/app/services/cinemeta_service.py +++ b/app/services/cinemeta_service.py @@ -7,18 +7,29 @@ class CinemetaService: def __init__(self): self.base_url = "https://v3-cinemeta.strem.io" + self._client: httpx.AsyncClient | None = None + + def _get_client(self) -> httpx.AsyncClient: + if self._client is None or self._client.is_closed: + self._client = httpx.AsyncClient(timeout=10.0, follow_redirects=True) + return self._client + + async def close(self) -> None: + if self._client and not self._client.is_closed: + await self._client.aclose() + self._client = None async def get_metadata(self, imdb_id: str, content_type: str) -> dict[str, any]: url = 
f"{self.base_url}/meta/{content_type}/{imdb_id}.json" - async with httpx.AsyncClient(timeout=10.0) as client: - try: - response = await client.get(url, follow_redirects=True) - response.raise_for_status() # Raise an exception for 4xx/5xx responses - json_response = response.json() - return json_response.get("meta", {}) - except (httpx.HTTPStatusError, httpx.RequestError, json.JSONDecodeError) as e: - logger.error(f"Error getting metadata for {imdb_id}: {e}") - return {} + client = self._get_client() + try: + response = await client.get(url) + response.raise_for_status() + json_response = response.json() + return json_response.get("meta", {}) + except (httpx.HTTPStatusError, httpx.RequestError, json.JSONDecodeError) as e: + logger.error(f"Error getting metadata for {imdb_id}: {e}") + return {} cinemeta_service = CinemetaService() diff --git a/app/services/context.py b/app/services/context.py new file mode 100644 index 0000000..a9d9c7b --- /dev/null +++ b/app/services/context.py @@ -0,0 +1,152 @@ +from dataclasses import dataclass +from typing import Any + +from fastapi import HTTPException +from loguru import logger + +from app.core.security import redact_token +from app.core.settings import UserSettings, get_default_settings +from app.models.library import LibraryCollection +from app.services.auth import auth_service +from app.services.stremio.service import StremioBundle +from app.services.token_store import token_store +from app.services.user_cache import user_cache + + +@dataclass +class UserContext: + """Everything a request handler needs about a user. + + The caller MUST call close() when done (or use as async context manager). 
+ """ + + token: str + credentials: dict[str, Any] + user_settings: UserSettings + auth_key: str | None + library: LibraryCollection + bundle: StremioBundle + + async def close(self): + await self.bundle.close() + + async def __aenter__(self): + return self + + async def __aexit__(self, *exc): + await self.close() + + +def extract_settings(credentials: dict[str, Any]) -> UserSettings: + """Parse UserSettings from credentials, falling back to defaults.""" + settings_dict = credentials.get("settings", {}) + return UserSettings(**settings_dict) if settings_dict else get_default_settings() + + +async def load_user_context( + token: str, + *, + require_auth: bool = True, +) -> UserContext: + """Load credentials, settings, auth key, and library for a token. + + The library is sourced from `user_settings.watch_history_source`: + Trakt/Simkl history is converted to a LibraryCollection so the + "Because you watched/loved" catalogs and other library-driven recommenders + see the user's external history. On any external-fetch failure we fall + back to the Stremio library so the user still gets recommendations. + """ + if not token: + raise HTTPException( + status_code=401, + detail="Missing token. Please reconfigure the addon.", + ) + + credentials = await token_store.get_user_data(token) + if not credentials: + raise HTTPException( + status_code=401, + detail="Token not found. Please reconfigure the addon.", + ) + + user_settings = extract_settings(credentials) + bundle = StremioBundle() + + try: + if require_auth: + auth_key = await auth_service.require_auth_key(bundle, credentials, token) + else: + auth_key = await auth_service.resolve_auth_key_with_bundle(bundle, credentials, token) + + configured_source = user_settings.watch_history_source + + # Drop the cached library if it was built from a different source than + # the one currently configured — otherwise switching sources in the + # configure page would silently keep serving the old (wrong) library. 
+ cached = await user_cache.get_library_items(token) + if cached and getattr(cached, "source", "stremio") != configured_source: + logger.info( + f"[{redact_token(token)}] Cached library source " + f"'{cached.source}' != configured '{configured_source}'; invalidating." + ) + await user_cache.invalidate_library_items(token) + cached = None + + library = cached + if not library: + library = await _fetch_library_for_source(configured_source, user_settings, token, bundle, auth_key) + if library is not None: + await user_cache.set_library_items(token, library) + + if not library: + library = LibraryCollection() + + return UserContext( + token=token, + credentials=credentials, + user_settings=user_settings, + auth_key=auth_key, + library=library, + bundle=bundle, + ) + except Exception: + await bundle.close() + raise + + +async def _fetch_library_for_source( + source: str, + user_settings: UserSettings, + token: str, + bundle: StremioBundle, + auth_key: str | None, +) -> LibraryCollection | None: + """Fetch the LibraryCollection for the configured watch_history_source. + + Trakt/Simkl: convert WatchHistory to a LibraryCollection. On failure + (missing token, revoked token, network), fall back to Stremio so the + user still sees recommendations from whatever Stremio knows about them. + Stremio: pull directly from the bundle's library service. + """ + if source in ("trakt", "simkl"): + from app.services.profile.service import ProfileService + + profile_service = ProfileService() + external = await profile_service.fetch_external_library(source, user_settings, token) + if external is not None: + logger.info( + f"[{redact_token(token)}] Built library from {source}: " + f"{len(external.loved)} loved, {len(external.liked)} liked, " + f"{len(external.watched)} watched" + ) + return external + + logger.warning( + f"[{redact_token(token)}] External {source} fetch returned no history; " "falling back to Stremio library." 
+ ) + + if auth_key: + logger.info(f"[{redact_token(token)}] Fetching library from Stremio") + return await bundle.library.get_library_items(auth_key) + + return None diff --git a/app/services/gemini.py b/app/services/gemini.py index 52a1276..e11ebe2 100644 --- a/app/services/gemini.py +++ b/app/services/gemini.py @@ -107,7 +107,11 @@ async def generate_structured_async( contents=prompt, config=config, ) - return json.loads(response.text) + try: + return json.loads(response.text) + except json.JSONDecodeError as e: + logger.warning(f"Gemini returned non-JSON response: {e}; body={response.text[:200]!r}") + return None except Exception as e: logger.exception(f"Error generating structured content with Gemini Flash: {e}") return None diff --git a/app/services/interest_summary.py b/app/services/interest_summary.py deleted file mode 100644 index ddafe9c..0000000 --- a/app/services/interest_summary.py +++ /dev/null @@ -1,112 +0,0 @@ -from loguru import logger - -from app.models.taste_profile import TasteProfile -from app.services.gemini import gemini_service -from app.services.profile.constants import GENRE_MAP - - -class InterestSummaryService: - def _get_system_prompt(self) -> str: - return ( - "You are a film analyst and recommender system expert.\n" - "Your task is to analyze a user's taste profile data and generate an engaging " - "summary of their viewing preferences.\n\n" - "The summary should:\n" - '1. Be written in the second person ("You love...", "Your taste leans towards...").\n' - "2. Be a short paragraph: 3-5 sentences, so you can capture nuance and variety.\n" - '3. Capture the main vibe of their interests (e.g., "fast-paced action," ' - '"dark historical dramas," "lighthearted animation").\n' - "4. Prioritize genres and keywords as the strongest signals of taste.\n" - "5. Mention specific eras, countries, or runtime preferences when they add color.\n" - "6. 
Sound natural, premium, and personalized—like a thoughtful friend describing their taste.\n\n" - "Do NOT mention specific IDs or raw metrics. Translate the data into a narrative." - ) - - def _format_profile_data(self, profile: TasteProfile) -> str: - """Format all available profile data into a structured context string. - - Genres and keywords are primary signals; eras, countries, and runtime are context. - We include more of each so the summary can be richer and longer. - """ - parts: list[str] = [] - - # --- Primary: more genres and keywords for a richer summary --- - top_genres = profile.get_top_genres(limit=5) - genre_names = [GENRE_MAP.get(g_id, f"Unknown({g_id})") for g_id, _ in top_genres] - if genre_names: - parts.append(f"[Primary] Top Genres (strongest first): {', '.join(genre_names)}") - - top_keywords = profile.get_top_keywords(limit=15) - if top_keywords: - keyword_ids = [str(k_id) for k_id, _ in top_keywords] - parts.append(f"[Primary] Top Keyword IDs (higher = more watched): {', '.join(keyword_ids)}") - - top_countries = [country for country, _ in profile.get_top_countries(limit=2)] - if top_countries: - parts.append(f"[Context] Preferred Countries: {', '.join(top_countries)}") - - top_runtimes = sorted(profile.runtime_bucket_scores.items(), key=lambda x: x[1], reverse=True) - runtime_prefs = [bucket for bucket, _ in top_runtimes[:3]] - if runtime_prefs: - parts.append(f"[Context] Runtime Preference: {', '.join(runtime_prefs)}") - - return "\n".join(parts) - - async def generate_summary( - self, - profile: TasteProfile, - api_key: str, - keyword_names: dict[int, str] | None = None, - ) -> str: - """Generate a text summary of the user's interest profile using Gemini. - - Args: - profile: The user's TasteProfile. - api_key: Gemini API key (required). - keyword_names: Optional mapping of keyword ID -> name for richer context. - - Returns: - Generated summary string, or empty string on failure. 
- """ - if not api_key: - return "" - - try: - profile_text = self._format_profile_data(profile) - if not profile_text: - return "" - - # Enrich with resolved keyword names if available - if keyword_names: - top_keywords = profile.get_top_keywords(limit=12) - resolved = [keyword_names[k_id] for k_id, _ in top_keywords if k_id in keyword_names] - if resolved: - # Replace the keyword IDs line with actual names - profile_text = profile_text.replace( - next( - (line for line in profile_text.split("\n") if "Keyword IDs" in line), - "", - ), - f"[Primary] Top Keywords: {', '.join(resolved)}", - ) - - prompt = ( - "Based on the following user profile data, write an interest summary (3-5 sentences).\n" - "Focus primarily on [Primary] signals (genres and keywords); use [Context] to add " - "flavor. Make it feel personal and specific to this viewer.\n\n" - f"{profile_text}" - ) - - summary = await gemini_service.generate_flash_content_async( - prompt=prompt, - system_instruction=self._get_system_prompt(), - api_key=api_key, - ) - - return summary - except Exception as e: - logger.error(f"Failed to generate interest summary: {e}") - return "" - - -interest_summary_service = InterestSummaryService() diff --git a/app/services/language_service.py b/app/services/language_service.py new file mode 100644 index 0000000..a9903e7 --- /dev/null +++ b/app/services/language_service.py @@ -0,0 +1,46 @@ +import asyncio + +from loguru import logger + +from app.services.tmdb.service import TMDBService, get_tmdb_service + + +async def fetch_languages_list() -> list[dict[str, str]]: + tmdb_service: TMDBService = get_tmdb_service() + tasks = [ + tmdb_service.get_primary_translations(), + tmdb_service.get_languages(), + tmdb_service.get_countries(), + ] + results = await asyncio.gather(*tasks, return_exceptions=True) + for label, r in zip(("primary_translations", "languages", "countries"), results): + if isinstance(r, Exception): + logger.warning(f"TMDB {label} fetch failed: {r}") + 
primary_translations = results[0] if not isinstance(results[0], Exception) else [] + languages = results[1] if not isinstance(results[1], Exception) else [] + countries = results[2] if not isinstance(results[2], Exception) else [] + + language_map = {lang["iso_639_1"]: lang["english_name"] for lang in languages} + country_map = {country["iso_3166_1"]: country["english_name"] for country in countries} + + result = [] + for element in primary_translations: + # element looks like "en-US" + parts = element.split("-") + if len(parts) != 2: + continue + + lang_code, country_code = parts + language_name = language_map.get(lang_code) + country_name = country_map.get(country_code) + + if language_name and country_name: + result.append( + { + "iso_639_1": element, + "language": language_name, + "country": country_name, + } + ) + result.sort(key=lambda x: (x["iso_639_1"] != "en-US", x["language"])) + return result diff --git a/app/services/manifest.py b/app/services/manifest.py index dc3c855..ed0d593 100644 --- a/app/services/manifest.py +++ b/app/services/manifest.py @@ -1,19 +1,19 @@ +import copy from typing import Any -from fastapi import HTTPException from loguru import logger from app.core.config import settings from app.core.security import redact_token from app.core.settings import UserSettings, resolve_tmdb_api_key from app.core.version import __version__ -from app.services.catalog import DynamicCatalogService -from app.services.profile.integration import ProfileIntegration +from app.models.library import LibraryCollection +from app.services.catalog_definitions import DynamicCatalogService, sort_catalogs +from app.services.context import load_user_context +from app.services.profile.service import ProfileService from app.services.stremio.service import StremioBundle -from app.services.token_store import token_store from app.services.translation import apply_catalog_translation from app.services.user_cache import user_cache -from app.utils.catalog import 
cache_profile_and_watched_sets, sort_catalogs class ManifestService: @@ -27,9 +27,9 @@ def get_base_manifest() -> dict[str, Any]: "version": __version__, "name": settings.ADDON_NAME, "description": "Movie and series recommendations based on your Stremio library.", - "logo": ("https://raw.githubusercontent.com/TimilsinaBimal/Watchly/refs/heads/main/app/static/logo.png"), + "logo": ("https://raw.githubusercontent.com/TimilsinaBimal/Watchly" "/refs/heads/main/app/static/logo.png"), "background": ( - "https://raw.githubusercontent.com/TimilsinaBimal/Watchly/refs/heads/main/app/static/cover.png" + "https://raw.githubusercontent.com/TimilsinaBimal/Watchly" "/refs/heads/main/app/static/cover.png" ), "resources": ["catalog"], "types": ["movie", "series"], @@ -39,73 +39,40 @@ def get_base_manifest() -> dict[str, Any]: "stremioAddonsConfig": { "issuer": "https://stremio-addons.net", "signature": ( - "eyJhbGciOiJkaXIiLCJlbmMiOiJBMTI4Q0JDLUhTMjU2In0..WSrhzzlj1TuDycD6QoVLuA.Dzmxzr4y83uqQF15r4tC1bB9-vtZRh1Rvy4BqgDYxu91c2esiJuov9KnnI_cboQCgZS7hjwnIqRSlQ-jEyGwXHHRerh9QklyfdxpXqNUyBgTWFzDOVdVvDYJeM_tGMmR.sezAChlWGV7lNS-t9HWB6A" # noqa + "eyJhbGciOiJkaXIiLCJlbmMiOiJBMTI4Q0JDLUhTMjU2In0" + "..WSrhzzlj1TuDycD6QoVLuA" + ".Dzmxzr4y83uqQF15r4tC1bB9-vtZRh1Rvy4BqgDYxu91c2esiJuov9KnnI_cboQC" + "gZS7hjwnIqRSlQ-jEyGwXHHRerh9QklyfdxpXqNUyBgTWFzDOVdVvDYJeM_tGMmR" + ".sezAChlWGV7lNS-t9HWB6A" # noqa ), }, } - async def _resolve_auth_key(self, bundle: StremioBundle, credentials: dict[str, Any], token: str) -> str | None: - """Resolve and validate auth key, refreshing if needed.""" - auth_key = credentials.get("authKey") - email = credentials.get("email") - password = credentials.get("password") - - is_valid = False - if auth_key: - try: - await bundle.auth.get_user_info(auth_key) - is_valid = True - except Exception as e: - logger.debug(f"Auth key check failed for {email or 'unknown'}: {e}") - - if not is_valid and email and password: - try: - auth_key = await bundle.auth.login(email, password) - 
# Update store - credentials["authKey"] = auth_key - await token_store.update_user_data(token, credentials) - except Exception as e: - logger.error(f"Failed to refresh auth key during manifest fetch: {e}") - return None - - return auth_key - async def cache_library_and_profiles( - self, bundle: StremioBundle, auth_key: str, user_settings: UserSettings, token: str - ) -> dict[str, Any]: + self, + bundle: StremioBundle, + auth_key: str, + user_settings: UserSettings, + token: str, + ) -> LibraryCollection: + """Fetch and cache library items and profiles for a user. + + Called during token creation to pre-cache data so manifest generation is fast. """ - Fetch and cache library items and profiles for a user. - - This should be called during token creation to pre-cache data - so manifest generation is fast. - - Args: - bundle: StremioBundle instance - auth_key: Stremio auth key - user_settings: User settings - token: User token - - Returns: - Library items dictionary - """ - # Fetch library items logger.info(f"[{redact_token(token)}] Fetching library items for caching") library_items = await bundle.library.get_library_items(auth_key) - - # Cache library items using centralized cache service await user_cache.set_library_items(token, library_items) logger.debug(f"[{redact_token(token)}] Cached library items") - # Build and cache profiles for both movie and series language = user_settings.language tmdb_key = resolve_tmdb_api_key(user_settings) - integration_service = ProfileIntegration(language=language, tmdb_api_key=tmdb_key) + profile_service = ProfileService(language=language, tmdb_api_key=tmdb_key) for content_type in ["movie", "series"]: try: logger.info(f"[{redact_token(token)}] Building and caching profile for {content_type}") - _, _, _ = await cache_profile_and_watched_sets( - token, content_type, integration_service, library_items, bundle, auth_key + await profile_service.build_and_cache_profile( + token, content_type, library_items, bundle, auth_key, 
user_settings=user_settings ) logger.debug(f"[{redact_token(token)}] Cached profile and watched sets for {content_type}") except Exception as e: @@ -113,115 +80,52 @@ async def cache_library_and_profiles( return library_items - async def _ensure_library_and_profiles_cached( - self, bundle: StremioBundle, auth_key: str, user_settings: UserSettings, token: str - ) -> dict[str, Any]: - """Ensure library items and profiles are cached, fetching and building if needed.""" - # Try to get cached library items first - library_items = await user_cache.get_library_items(token) - - if library_items: - logger.debug(f"[{redact_token(token)}] Using cached library items for manifest") - return library_items - - # If not cached, fetch and cache - logger.info(f"[{redact_token(token)}] Library items not cached, fetching from Stremio for manifest") - return await self.cache_library_and_profiles(bundle, auth_key, user_settings, token) - - async def _build_dynamic_catalogs( - self, bundle: StremioBundle, auth_key: str, user_settings: UserSettings | None, token: str - ) -> list[dict[str, Any]]: - """Build dynamic catalogs for the manifest.""" - # check if cached, if not, fetch and cache - library_items = await user_cache.get_library_items(token) - if not library_items: - library_items = await self._ensure_library_and_profiles_cached(bundle, auth_key, user_settings, token) - await user_cache.set_library_items(token, library_items) - - tmdb_key = resolve_tmdb_api_key(user_settings) - dynamic_catalog_service = DynamicCatalogService(language=user_settings.language, tmdb_api_key=tmdb_key) - return await dynamic_catalog_service.get_dynamic_catalogs(library_items, user_settings, token=token) - - async def _translate_catalogs(self, catalogs: list[dict[str, Any]], language: str | None) -> list[dict[str, Any]]: - """Translate catalog names to target language.""" - if not language: - return catalogs - - translated_catalogs = [] - for cat in catalogs: - await apply_catalog_translation(cat, language) 
- translated_catalogs.append(cat) - - return translated_catalogs - - def _sort_catalogs( - self, catalogs: list[dict[str, Any]], user_settings: UserSettings | None - ) -> list[dict[str, Any]]: - """Sort catalogs according to user settings order.""" - if not user_settings: - return catalogs - - return sort_catalogs(catalogs, user_settings) - async def get_manifest_for_token(self, token: str) -> dict[str, Any]: - """ - Generate manifest for a given token. - - Args: - token: User token - - Returns: - Complete manifest dictionary - - Raises: - HTTPException: If token is invalid or credentials are missing - """ - if not token: - raise HTTPException(status_code=401, detail="Missing token. Please reconfigure the addon.") - - # Load user credentials and settings - creds = await token_store.get_user_data(token) - if not creds: - raise HTTPException(status_code=401, detail="Token not found. Please reconfigure the addon.") - - user_settings = None - try: - if creds.get("settings"): - user_settings = UserSettings(**creds["settings"]) - except Exception as e: - logger.error(f"[{redact_token(token)}] Error loading user data from token store: {e}") - raise HTTPException(status_code=401, detail="Invalid token session. 
Please reconfigure.") - + """Generate manifest for a given token.""" base_manifest = self.get_base_manifest() - bundle = StremioBundle() - fetched_catalogs = [] + ctx = await load_user_context(token, require_auth=False) + fetched_catalogs: list[dict[str, Any]] = [] try: - # Resolve auth key - auth_key = await self._resolve_auth_key(bundle, creds, token) - - if auth_key: - fetched_catalogs = await self._build_dynamic_catalogs(bundle, auth_key, user_settings, token) + if ctx.auth_key: + tmdb_key = resolve_tmdb_api_key(ctx.user_settings) + catalog_def_service = DynamicCatalogService(language=ctx.user_settings.language, tmdb_api_key=tmdb_key) + fetched_catalogs = await catalog_def_service.get_dynamic_catalogs( + ctx.library, ctx.user_settings, token=token + ) except Exception as e: logger.exception(f"[{redact_token(token)}] Dynamic catalog build failed: {e}") fetched_catalogs = [] finally: - await bundle.close() - - # Combine base catalogs with fetched catalogs - all_catalogs = [c.copy() for c in base_manifest["catalogs"]] + [c.copy() for c in fetched_catalogs] + await ctx.close() - # Translate catalogs - language = user_settings.language if user_settings else None - translated_catalogs = await self._translate_catalogs(all_catalogs, language) + # deepcopy: catalogs contain nested dicts/lists (extra params, options) that + # downstream code mutates (translation, sort). Shallow copies would mutate + # shared inner objects across users. 
+ all_catalogs = [copy.deepcopy(c) for c in base_manifest["catalogs"]] + [ + copy.deepcopy(c) for c in fetched_catalogs + ] - # Sort catalogs - sorted_catalogs = self._sort_catalogs(translated_catalogs, user_settings) + language = ctx.user_settings.language + translated = await self._translate_catalogs(all_catalogs, language) + sorted_catalogs = sort_catalogs(translated, ctx.user_settings) if sorted_catalogs: base_manifest["catalogs"] = sorted_catalogs return base_manifest + async def _translate_catalogs(self, catalogs: list[dict[str, Any]], language: str | None) -> list[dict[str, Any]]: + """Translate catalog names to target language.""" + if not language: + return catalogs + + translated_catalogs = [] + for cat in catalogs: + await apply_catalog_translation(cat, language) + translated_catalogs.append(cat) + + return translated_catalogs + manifest_service = ManifestService() diff --git a/app/services/poster_ratings/rpdb.py b/app/services/poster_ratings/rpdb.py index 1e902f1..e013e30 100644 --- a/app/services/poster_ratings/rpdb.py +++ b/app/services/poster_ratings/rpdb.py @@ -7,12 +7,22 @@ class RPDBService: def __init__(self): self.base_url = "https://api.ratingposterdb.com" + self._client: httpx.AsyncClient | None = None + + def _get_client(self) -> httpx.AsyncClient: + if self._client is None or self._client.is_closed: + self._client = httpx.AsyncClient(timeout=10.0) + return self._client + + async def close(self) -> None: + if self._client and not self._client.is_closed: + await self._client.aclose() + self._client = None async def validate_api_key(self, api_key: str) -> bool: url = f"{self.base_url}/{api_key}/isValid" - async with httpx.AsyncClient(timeout=10.0) as client: - response = await client.get(url) - return response.status_code == 200 + response = await self._get_client().get(url) + return response.status_code == 200 def get_poster_url( self, diff --git a/app/services/poster_ratings/top_posters.py b/app/services/poster_ratings/top_posters.py index 
8722bec..2f2f40b 100644 --- a/app/services/poster_ratings/top_posters.py +++ b/app/services/poster_ratings/top_posters.py @@ -13,14 +13,24 @@ def __init__(self): "User-Agent": f"Watchly/{__version__} (+https://github.com/TimilsinaBimal/Watchly)", "Accept": "application/json", } + self._client: httpx.AsyncClient | None = None + + def _get_client(self) -> httpx.AsyncClient: + if self._client is None or self._client.is_closed: + self._client = httpx.AsyncClient(timeout=10.0, headers=self.headers) + return self._client + + async def close(self) -> None: + if self._client and not self._client.is_closed: + await self._client.aclose() + self._client = None async def validate_api_key(self, api_key: str) -> bool: url = f"{self.base_url}/auth/verify/{api_key}" - async with httpx.AsyncClient(timeout=10.0) as client: - response = await client.get(url, headers=self.headers) - response.raise_for_status() - json_data = response.json() - return json_data.get("valid", False) + response = await self._get_client().get(url) + response.raise_for_status() + json_data = response.json() + return json_data.get("valid", False) def get_poster_url(self, api_key: str, provider: Literal["imdb", "tmdb", "tvdb"], item_id: str, **kwargs) -> str: url = f"{self.base_url}/{api_key}/{provider}/poster-default/{item_id}.jpg" diff --git a/app/services/profile/__init__.py b/app/services/profile/__init__.py index a9b1489..9bb00ee 100644 --- a/app/services/profile/__init__.py +++ b/app/services/profile/__init__.py @@ -1,22 +1,7 @@ -""" -Profile System - Additive, Transparent Design. +"""Profile service exports.""" -This package implements a transparent, additive user profile system. -No hidden interactions, easy to debug, powerful enough for all row types. 
-""" - -from app.services.profile.builder import ProfileBuilder -from app.services.profile.evidence import EvidenceCalculator -from app.services.profile.integration import ProfileIntegration -from app.services.profile.sampling import SmartSampler -from app.services.profile.scorer import ProfileScorer -from app.services.profile.vectorizer import ItemVectorizer +from app.services.profile.service import ProfileService __all__ = [ - "ProfileBuilder", - "ProfileScorer", - "EvidenceCalculator", - "ItemVectorizer", - "SmartSampler", - "ProfileIntegration", + "ProfileService", ] diff --git a/app/services/profile/builder.py b/app/services/profile/builder.py index 4face31..2f76383 100644 --- a/app/services/profile/builder.py +++ b/app/services/profile/builder.py @@ -5,8 +5,7 @@ from loguru import logger -from app.models.scoring import ScoredItem -from app.models.taste_profile import TasteProfile +from app.models.profile import ScoredItem, TasteProfile from app.services.profile.constants import ( CAP_CAST, CAP_COUNTRY, @@ -234,10 +233,10 @@ def _accumulate_features( job = "" if crew_id: - # Only count actual directors (for movies) and creators (for TV series) if job in ["director", "creator"]: weight = evidence_weight * FEATURE_WEIGHT_CREATOR profile.director_scores[crew_id] = profile.director_scores.get(crew_id, 0.0) + weight + profile.director_frequency[crew_id] = profile.director_frequency.get(crew_id, 0) + 1 if frequencies is not None: frequencies["directors"][crew_id] += 1 @@ -250,9 +249,9 @@ def _accumulate_features( position_weight = 1.0 if cast_id: - # Use evidence weight multiplied by position weight weight = evidence_weight * FEATURE_WEIGHT_CREATOR * position_weight profile.cast_scores[cast_id] = profile.cast_scores.get(cast_id, 0.0) + weight + profile.cast_frequency[cast_id] = profile.cast_frequency.get(cast_id, 0) + 1 if frequencies is not None: frequencies["cast"][cast_id] += 1 @@ -306,39 +305,20 @@ def _apply_frequency_multipliers(self, profile: TasteProfile, 
frequencies: dict[ @staticmethod def _apply_caps(profile: TasteProfile) -> None: """ - Apply score caps to prevent unbounded growth. - - Args: - profile: Profile to cap + Apply score caps to prevent unbounded growth (both positive and negative). """ - # Cap genres - for genre_id in profile.genre_scores: - profile.genre_scores[genre_id] = min(profile.genre_scores[genre_id], CAP_GENRE) - - # Cap keywords - for keyword_id in profile.keyword_scores: - profile.keyword_scores[keyword_id] = min(profile.keyword_scores[keyword_id], CAP_KEYWORD) - - # Cap directors - for director_id in profile.director_scores: - profile.director_scores[director_id] = min(profile.director_scores[director_id], CAP_DIRECTOR) - - # Cap cast - for cast_id in profile.cast_scores: - profile.cast_scores[cast_id] = min(profile.cast_scores[cast_id], CAP_CAST) - - # Cap eras - for era in profile.era_scores: - profile.era_scores[era] = min(profile.era_scores[era], CAP_ERA) - - # Cap countries - for country in profile.country_scores: - profile.country_scores[country] = min(profile.country_scores[country], CAP_COUNTRY) - - # Cap runtime buckets - for runtime_bucket in profile.runtime_bucket_scores: - current_score = profile.runtime_bucket_scores[runtime_bucket] - profile.runtime_bucket_scores[runtime_bucket] = min(current_score, CAP_RUNTIME) + cap_pairs = [ + (profile.genre_scores, CAP_GENRE), + (profile.keyword_scores, CAP_KEYWORD), + (profile.director_scores, CAP_DIRECTOR), + (profile.cast_scores, CAP_CAST), + (profile.era_scores, CAP_ERA), + (profile.country_scores, CAP_COUNTRY), + (profile.runtime_bucket_scores, CAP_RUNTIME), + ] + for scores, cap in cap_pairs: + for key in scores: + scores[key] = max(-cap, min(scores[key], cap)) async def update_profile_incrementally( self, diff --git a/app/services/profile/constants.py b/app/services/profile/constants.py index 2131ce9..a035db7 100644 --- a/app/services/profile/constants.py +++ b/app/services/profile/constants.py @@ -55,9 +55,6 @@ 
MAXIMUM_POPULARITY_SCORE: Final[float] = 100.0 # Increased from 15.0 to allow popular items -# Genre whitelist limit (top N genres) -GENRE_WHITELIST_LIMIT: Final[int] = 7 - # Runtime Bucket Boundaries (in minutes) RUNTIME_BUCKET_SHORT_MAX_SERIES: Final[int] = 30 # < 30 min RUNTIME_BUCKET_MEDIUM_MAX_SERIES: Final[int] = 60 # 30-60 min, > 60 is long diff --git a/app/services/profile/evidence.py b/app/services/profile/evidence.py index 1e0bd9d..4c06c0d 100644 --- a/app/services/profile/evidence.py +++ b/app/services/profile/evidence.py @@ -2,7 +2,7 @@ from datetime import datetime, timezone from typing import Literal -from app.models.scoring import ScoredItem +from app.models.profile import ScoredItem from app.services.profile.constants import ( EVIDENCE_WEIGHT_ADDED, EVIDENCE_WEIGHT_LIKED, @@ -12,25 +12,22 @@ RECENCY_HALF_LIFE_DAYS, ) +# Abandonment thresholds (in minutes of watch time) +_ABANDON_IGNORE_MINUTES = 15 # < 15 min: too short, ignore +_ABANDON_NEGATIVE_THRESHOLD = 0.30 # 15 min – 30%: mild negative + class EvidenceCalculator: """ Calculates evidence weights for user interactions. - Pure function: no side effects, easy to test. + Supports both legacy Stremio interaction types and explicit 1-10 ratings + from external sources (Trakt, Simkl). """ @staticmethod def get_interaction_type(item: ScoredItem) -> Literal["loved", "liked", "watched_high", "watched_medium", "added"]: - """ - Determine interaction type from scored item. 
- - Args: - item: ScoredItem with interaction data - - Returns: - Interaction type string - """ + """Determine interaction type from scored item.""" if item.item.is_loved: return "loved" if item.item.is_liked: @@ -39,22 +36,13 @@ def get_interaction_type(item: ScoredItem) -> Literal["loved", "liked", "watched return "watched_high" if item.completion_rate >= 0.4: return "watched_medium" - # Check if added to library (not watched, not removed, not temp) if not item.item.temp and not item.item.removed and item.completion_rate < 0.4: return "added" - return "watched_medium" # Fallback + return "watched_medium" @staticmethod def get_base_weight(interaction_type: str) -> float: - """ - Get base evidence weight for interaction type. - - Args: - interaction_type: Type of interaction - - Returns: - Base weight value - """ + """Get base evidence weight for interaction type (legacy bucket system).""" weights = { "loved": EVIDENCE_WEIGHT_LOVED, "liked": EVIDENCE_WEIGHT_LIKED, @@ -65,18 +53,46 @@ def get_base_weight(interaction_type: str) -> float: return weights.get(interaction_type, EVIDENCE_WEIGHT_WATCHED_MEDIUM) @staticmethod - def calculate_recency_multiplier(last_interaction: datetime | None) -> float: + def weight_from_rating(rating: float) -> float: """ - Calculate recency multiplier using exponential decay. + Continuous evidence weight from an explicit 1-10 rating. - Args: - last_interaction: When the interaction occurred + Positive: 5→0.3, 6→0.8, 7→1.3, 8→1.8, 9→2.5, 10→3.0 + Negative: 1→-1.5, 2→-1.0, 3→-0.5, 4→-0.1 + """ + if rating >= 5: + return max(0.1, (rating - 4) / 2) + return (rating - 5) / 2 - Returns: - Multiplier (1.0 for recent, <1.0 for old) + @staticmethod + def weight_from_completion(completion: float, watch_time_minutes: float | None = None) -> float: """ + Evidence weight for unrated items based on watch completion. 
+ + Implements abandonment detection: + - < 15 min watched: ignore (weight 0.0) + - 15 min to 30% completion: mild negative (-0.2) + - 30%-70% completion: neutral (0.0) + - > 70% completion: positive (1.0) + """ + # If we have actual watch time, use the abandonment thresholds + if watch_time_minutes is not None and watch_time_minutes < _ABANDON_IGNORE_MINUTES: + return 0.0 + + if completion >= 0.7: + return 1.0 + if completion >= 0.3: + return 0.0 # Ambiguous — neutral + if watch_time_minutes is not None and watch_time_minutes >= _ABANDON_IGNORE_MINUTES: + return -0.2 # Gave it a fair shot and quit + # Low completion without enough info — treat as neutral + return 0.0 + + @staticmethod + def calculate_recency_multiplier(last_interaction: datetime | None) -> float: + """Calculate recency multiplier using exponential decay.""" if not last_interaction: - return 0.5 # No date = old, reduce weight + return 0.5 now = datetime.now(timezone.utc) if last_interaction.tzinfo is None: @@ -84,35 +100,74 @@ def calculate_recency_multiplier(last_interaction: datetime | None) -> float: days_ago = (now - last_interaction).days if days_ago < 0: - return 1.0 # Future date = treat as recent + return 1.0 - # Exponential decay: multiplier = exp(-days / half_life) multiplier = math.exp(-days_ago / RECENCY_HALF_LIFE_DAYS) - return max(0.1, multiplier) # Minimum 0.1 to keep some signal + return max(0.1, multiplier) @staticmethod def calculate_evidence_weight(item: ScoredItem) -> float: """ Calculate final evidence weight for an item. - Combines base weight (interaction type) with recency multiplier. - - Args: - item: ScoredItem with interaction data - - Returns: - Final evidence weight + Uses explicit rating if available (from external history sources), + otherwise falls back to the legacy interaction-type bucket system. + Abandonment detection is applied for unrated items. 
""" - interaction_type = EvidenceCalculator.get_interaction_type(item) - base_weight = EvidenceCalculator.get_base_weight(interaction_type) + # Check for an explicit rating (set by the WatchHistory → ScoredItem converter) + # The converter maps loved→is_loved (rating≥9) and liked→is_liked (rating≥7). + # For items with external ratings, we use the continuous scale. + has_explicit_rating = False + rating: float | None = None + + # Detect external-history items by checking the synthetic state pattern: + # External items have flaggedWatched=1 and a specific duration sentinel (6000) + # OR they have is_loved/is_liked set from external ratings. + # We use a simpler heuristic: if is_loved with flaggedWatched=1, compute from rating=9. + # For more granularity, we'll check the state for our sentinel. + state = item.item.state + + if item.item.is_loved: + # Could be Stremio loved (legacy) or external rating ≥ 9 + # Use rating-proportional weight for loved items + rating = 9.0 + has_explicit_rating = True + elif item.item.is_liked: + rating = 7.0 + has_explicit_rating = True + + if has_explicit_rating and rating is not None: + base_weight = EvidenceCalculator.weight_from_rating(rating) + else: + # Check for abandonment on unrated items + watch_time_minutes: float | None = None + if state.duration > 0 and state.timeWatched > 0: + watch_time_minutes = state.timeWatched / 60.0 + + completion = item.completion_rate + + # Use completion-based weight with abandonment detection + completion_weight = EvidenceCalculator.weight_from_completion(completion, watch_time_minutes) + + if ( + completion_weight == 0.0 + and watch_time_minutes is not None + and watch_time_minutes < _ABANDON_IGNORE_MINUTES + ): + # Too short, skip this item entirely + return 0.0 + + if completion_weight != 0.0: + base_weight = completion_weight + else: + # Fall back to legacy bucket system for ambiguous cases + interaction_type = EvidenceCalculator.get_interaction_type(item) + base_weight = 
EvidenceCalculator.get_base_weight(interaction_type) # Get last interaction date - last_interaction = item.item.state.lastWatched - if not last_interaction and interaction_type == "added": - # For added items, use mtime if available + last_interaction = state.lastWatched + if not last_interaction: try: - from datetime import datetime - if item.item.mtime: last_interaction = datetime.fromisoformat(item.item.mtime.replace("Z", "+00:00")) except Exception: diff --git a/app/services/profile/integration.py b/app/services/profile/integration.py deleted file mode 100644 index e987147..0000000 --- a/app/services/profile/integration.py +++ /dev/null @@ -1,233 +0,0 @@ -from typing import Any - -from loguru import logger - -from app.models.taste_profile import TasteProfile -from app.services.profile.builder import ProfileBuilder -from app.services.profile.constants import GENRE_WHITELIST_LIMIT -from app.services.profile.sampling import SmartSampler -from app.services.profile.vectorizer import ItemVectorizer -from app.services.recommendation.filtering import RecommendationFiltering -from app.services.scoring import ScoringService -from app.services.tmdb.service import get_tmdb_service -from app.services.user_cache import user_cache - - -class ProfileIntegration: - """ - Helper class to integrate taste profile services with existing systems. - """ - - def __init__(self, language: str = "en-US", tmdb_api_key: str | None = None): - self.scoring_service = ScoringService() - self.sampler = SmartSampler(self.scoring_service) - tmdb_service = get_tmdb_service(language=language, api_key=tmdb_api_key) - vectorizer = ItemVectorizer(tmdb_service) - self.builder = ProfileBuilder(vectorizer) - - async def build_profile_from_library( - self, - library_items: dict, - content_type: str, - stremio_service: Any = None, - auth_key: str | None = None, - ) -> tuple[TasteProfile | None, set[int], set[str]]: - """ - Build taste profile from library items and get watched sets. 
- - Args: - library_items: Library items dict from Stremio - content_type: Content type (movie/series) - stremio_service: Stremio service (optional, for watched sets) - auth_key: Auth key (optional, for watched sets) - - Returns: - Tuple of (profile, watched_tmdb, watched_imdb) - """ - # Get watched sets - watched_imdb, watched_tmdb = await RecommendationFiltering.get_exclusion_sets( - stremio_service, library_items, auth_key - ) - - # Convert library items to ScoredItems - all_items = ( - library_items.get("loved", []) - + library_items.get("liked", []) - + library_items.get("watched", []) - + library_items.get("added", []) - ) - typed_items = [it for it in all_items if it.get("type") == content_type] - - if not typed_items: - return None, watched_tmdb, watched_imdb - - # Sample items using SmartSampler (it expects raw library items dict) - library_items_dict = { - "loved": [it for it in library_items.get("loved", []) if it.get("type") == content_type], - "liked": [it for it in library_items.get("liked", []) if it.get("type") == content_type], - "watched": [it for it in library_items.get("watched", []) if it.get("type") == content_type], - "added": [it for it in library_items.get("added", []) if it.get("type") == content_type], - } - sampled = self.sampler.sample_items(library_items_dict, content_type) - - # Build profile - profile = await self.builder.build_profile(sampled, content_type=content_type) - - return profile, watched_tmdb, watched_imdb - - async def build_profile_incremental( - self, - library_items: dict, - content_type: str, - token: str, - stremio_service: Any = None, - auth_key: str | None = None, - ) -> tuple[TasteProfile | None, set[int], set[str]]: - """ - Build profile incrementally if possible, fallback to full rebuild. 
- - Args: - library_items: Library items dict from Stremio - content_type: Content type (movie/series) - token: User token for change detection - stremio_service: Stremio service (optional, for watched sets) - auth_key: Auth key (optional, for watched sets) - - Returns: - Tuple of (profile, watched_tmdb, watched_imdb) - """ - # Get watched sets - watched_imdb, watched_tmdb = await RecommendationFiltering.get_exclusion_sets( - stremio_service, library_items, auth_key - ) - - # Convert library items to ScoredItems for change detection - all_items = ( - library_items.get("loved", []) - + library_items.get("liked", []) - + library_items.get("watched", []) - + library_items.get("added", []) - ) - typed_items = [it for it in all_items if it.get("type") == content_type] - - if not typed_items: - return None, watched_tmdb, watched_imdb - - # Check if we can use incremental update - try: - # Check if library has changed - library_changed = await user_cache.has_library_changed(token, content_type, typed_items) - - if not library_changed: - # No changes - return existing profile - existing_profile = await user_cache.get_profile(token, content_type) - if existing_profile: - return existing_profile, watched_tmdb, watched_imdb - - # Try to get existing profile for incremental update - existing_profile = await user_cache.get_profile(token, content_type) - - if existing_profile: - # Check for removals or new items - processed_ids = existing_profile.processed_items - current_ids = {it.get("_id", it.get("id")) for it in typed_items if it.get("_id", it.get("id"))} - - # Check if this is a legacy profile (has scores but no processed_items) - is_legacy = not processed_ids and (existing_profile.genre_scores or existing_profile.director_scores) - - # If items were removed, or it's a legacy profile, we must do a full rebuild - if not processed_ids.issubset(current_ids) or is_legacy: - reason = "Legacy profile detected" if is_legacy else "Items removed from library" - 
logger.debug(f"[{token[:8]}...] {reason}, falling back to full rebuild") - # Fall through to full rebuild - else: - # Identify new items - new_item_ids = current_ids - processed_ids - - if not new_item_ids: - # No new items and no removals (maybe just metadata changed?) - # We can just return the existing profile - return existing_profile, watched_tmdb, watched_imdb - - logger.debug(f"[{token[:8]}...] Found {len(new_item_ids)} new items, using incremental update") - - # Filter library items to only new ones for sampling - new_library_items_dict = { - "loved": [ - it - for it in library_items.get("loved", []) - if it.get("type") == content_type and (it.get("_id") or it.get("id")) in new_item_ids - ], - "liked": [ - it - for it in library_items.get("liked", []) - if it.get("type") == content_type and (it.get("_id") or it.get("id")) in new_item_ids - ], - "watched": [ - it - for it in library_items.get("watched", []) - if it.get("type") == content_type and (it.get("_id") or it.get("id")) in new_item_ids - ], - "added": [ - it - for it in library_items.get("added", []) - if it.get("type") == content_type and (it.get("_id") or it.get("id")) in new_item_ids - ], - } - - # Sample only new items - sampled = self.sampler.sample_items(new_library_items_dict, content_type) - - if not sampled: - # Should not happen if new_item_ids is not empty, but just in case - return existing_profile, watched_tmdb, watched_imdb - - # Update existing profile incrementally - updated_profile = await self.builder.update_profile_incrementally( - existing_profile, sampled, content_type=content_type - ) - - # Update library hash to mark as processed - await user_cache.update_library_hash(token, content_type, typed_items) - - return updated_profile, watched_tmdb, watched_imdb - - except Exception as e: - logger.warning(f"[{token[:8]}...] Incremental update failed, falling back to full rebuild: {e}") - - # Fallback to full rebuild - logger.debug(f"[{token[:8]}...] 
Using full rebuild") - profile_tuple = await self.build_profile_from_library(library_items, content_type, stremio_service, auth_key) - profile, _, _ = profile_tuple - - # Update library hash after successful build - await user_cache.update_library_hash(token, content_type, typed_items) - - return profile, watched_tmdb, watched_imdb - - async def get_genre_whitelist( - self, - profile: TasteProfile, - content_type: str, - ) -> set[int]: - """ - Get genre whitelist from user's top genres in profile. - - Args: - profile: Taste profile - content_type: Content type (movie/series) - - Returns: - Set of top genre IDs - """ - try: - if not profile: - whitelist = set() - else: - # Get top genres - top_genres = profile.get_top_genres(limit=GENRE_WHITELIST_LIMIT) - whitelist = {int(genre_id) for genre_id, _ in top_genres} - return whitelist - except Exception as e: - logger.warning(f"Failed to build genre whitelist for {content_type}: {e}") - return set() diff --git a/app/services/profile/sampling.py b/app/services/profile/sampling.py index ea444e9..93313d2 100644 --- a/app/services/profile/sampling.py +++ b/app/services/profile/sampling.py @@ -1,127 +1,82 @@ -from typing import Any - -from app.models.scoring import ScoredItem +from app.models.library import LibraryCollection +from app.models.profile import ScoredItem from app.services.profile.constants import SMART_SAMPLING_MAX_ITEMS -from app.services.scoring import ScoringService +from app.services.profile.scoring import ScoringService -class SmartSampler: - """ - Smart sampling for profile building. +def sample_items( + library_items: LibraryCollection, + content_type: str, + scoring_service: ScoringService, + max_items: int = SMART_SAMPLING_MAX_ITEMS, +) -> list[ScoredItem]: + """Sample items for profile building with quota-based selection. Strategy: 1. Always include all loved/liked/added items (strong signals) 2. Fill remaining slots with top watched items by score 3. 
Limit total to prevent excessive API calls """ - - def __init__(self, scoring_service: ScoringService): - """ - Initialize smart sampler. - - Args: - scoring_service: Service for scoring items - """ - self.scoring_service = scoring_service - - def sample_items( - self, - library_items: dict[str, list[dict[str, Any]]], - content_type: str, - max_items: int = SMART_SAMPLING_MAX_ITEMS, - ) -> list[ScoredItem]: - """ - Sample items for profile building. - - Args: - library_items: Library items dict with 'loved', 'liked', 'watched', 'added' keys - content_type: Content type to filter (movie/series) - max_items: Maximum items to return - - Returns: - List of ScoredItem objects - """ - # Get all items of the requested type - all_items = ( - library_items.get("loved", []) - + library_items.get("liked", []) - + library_items.get("watched", []) - + library_items.get("added", []) - ) - typed_items = [it for it in all_items if it.get("type") == content_type] - - if not typed_items: - return [] - - if len(typed_items) <= max_items: - # score all typed items and return - return [self.scoring_service.process_item(it) for it in typed_items] - - # De-duplicate by ID - unique_items = {} - for it in typed_items: - item_id = it.get("_id") - if item_id: - unique_items[item_id] = it - - # If still within limit after de-duplication - if len(unique_items) <= max_items: - return [self.scoring_service.process_item(it) for it in unique_items.values()] - - # Get set of added item IDs for classification - added_item_ids = {it.get("_id") for it in library_items.get("added", [])} - - # Separate items into pools and score them - loved_liked_pool = [] - added_pool = [] - watched_pool = [] - - for it in unique_items.values(): - scored = self.scoring_service.process_item(it) - if scored.source_type in ["loved", "liked"]: - loved_liked_pool.append(scored) - elif it.get("_id") in added_item_ids: - added_pool.append(scored) - else: - watched_pool.append(scored) - - # Sort pools by score to ensure we 
take the most relevant items first - # If we sort this, we will get high scoring items, but if we don't sort this, - # we will get recent items. Maybe recent is good? I think yeah. Lets do that... - # it will likely by almost similar but not confirmed. - # loved_liked_pool.sort(key=lambda x: x.score, reverse=True) - # added_pool.sort(key=lambda x: x.score, reverse=True) - # watched_pool.sort(key=lambda x: x.score, reverse=True) - - # Step 1: Fill quotas - final_scored_items: list[ScoredItem] = [] - used_ids: set[str] = set() - - loved_quota = int(max_items * 0.40) - added_quota = int(max_items * 0.20) - watched_quota = max_items - loved_quota - added_quota - - # Add initial quotas - for pool, quota in [ - (loved_liked_pool, loved_quota), - (added_pool, added_quota), - (watched_pool, watched_quota), - ]: - for scored in pool[:quota]: - final_scored_items.append(scored) - used_ids.add(scored.item.id) - - # Step 2: Backfill if we have remaining slots - remaining_slots = max_items - len(final_scored_items) - if remaining_slots > 0: - # Priority for backfill: Loved > Added > Watched - for pool in [loved_liked_pool, added_pool, watched_pool]: - for scored in pool: - if remaining_slots <= 0: - break - if scored.item.id not in used_ids: - final_scored_items.append(scored) - used_ids.add(scored.item.id) - remaining_slots -= 1 - - return final_scored_items + typed_items = [it for it in library_items.all_items() if it.type == content_type] + + if not typed_items: + return [] + + if len(typed_items) <= max_items: + return [scoring_service.process_item(it) for it in typed_items] + + # De-duplicate by ID + unique_items: dict[str, any] = {} + for it in typed_items: + if it.id: + unique_items[it.id] = it + + if len(unique_items) <= max_items: + return [scoring_service.process_item(it) for it in unique_items.values()] + + added_item_ids = {it.id for it in library_items.added} + + # Separate into pools and score + loved_liked_pool: list[ScoredItem] = [] + added_pool: 
list[ScoredItem] = [] + watched_pool: list[ScoredItem] = [] + + for it in unique_items.values(): + scored = scoring_service.process_item(it) + if scored.source_type in ["loved", "liked"]: + loved_liked_pool.append(scored) + elif it.id in added_item_ids: + added_pool.append(scored) + else: + watched_pool.append(scored) + + # Fill quotas + final: list[ScoredItem] = [] + used_ids: set[str] = set() + + loved_quota = int(max_items * 0.40) + added_quota = int(max_items * 0.20) + watched_quota = max_items - loved_quota - added_quota + + for pool, quota in [ + (loved_liked_pool, loved_quota), + (added_pool, added_quota), + (watched_pool, watched_quota), + ]: + for scored in pool[:quota]: + final.append(scored) + used_ids.add(scored.item.id) + + # Backfill remaining slots (priority: Loved > Added > Watched) + remaining = max_items - len(final) + if remaining > 0: + for pool in [loved_liked_pool, added_pool, watched_pool]: + for scored in pool: + if remaining <= 0: + break + if scored.item.id not in used_ids: + final.append(scored) + used_ids.add(scored.item.id) + remaining -= 1 + + return final diff --git a/app/services/profile/scorer.py b/app/services/profile/scorer.py index e67efe8..fc906ff 100644 --- a/app/services/profile/scorer.py +++ b/app/services/profile/scorer.py @@ -1,6 +1,6 @@ from typing import Any -from app.models.taste_profile import TasteProfile +from app.models.profile import TasteProfile from app.services.profile.constants import ( FEATURE_WEIGHT_COUNTRY, FEATURE_WEIGHT_CREATOR, diff --git a/app/services/scoring.py b/app/services/profile/scoring.py similarity index 84% rename from app/services/scoring.py rename to app/services/profile/scoring.py index 7cfdc6c..a8e50ea 100644 --- a/app/services/scoring.py +++ b/app/services/profile/scoring.py @@ -3,7 +3,8 @@ from loguru import logger -from app.models.scoring import ScoredItem, StremioLibraryItem +from app.models.library import StremioLibraryItem +from app.models.profile import ScoredItem class 
ScoringService: @@ -22,12 +23,11 @@ class ScoringService: WEIGHT_EXPLICIT_RATING = 0.35 ADDED_TO_LIBRARY_WEIGHT = 0.08 - def process_item(self, raw_item: dict) -> ScoredItem: + def process_item(self, raw_item: dict | StremioLibraryItem) -> ScoredItem: """ - Process a raw Stremio item dictionary into a ScoredItem. + Process a Stremio item (dict or typed model) into a ScoredItem. """ - # Convert dict to Pydantic model for validation and typing - item = StremioLibraryItem(**raw_item) + item = raw_item if isinstance(raw_item, StremioLibraryItem) else StremioLibraryItem(**raw_item) score_data = self._calculate_score_components(item) @@ -40,28 +40,6 @@ def process_item(self, raw_item: dict) -> ScoredItem: source_type="loved" if item.is_loved else ("liked" if item.is_liked else "watched"), ) - def calculate_score( - self, - item: dict | StremioLibraryItem, - is_loved: bool = False, - is_liked: bool = False, - ) -> float: - """ - Backwards compatible method to just get the float score. - Accepts either a raw dict or a StremioLibraryItem. - """ - if isinstance(item, dict): - # Temporarily inject flags if passed separately (legacy support) - if "_is_loved" not in item: - item["_is_loved"] = is_loved - if "_is_liked" not in item: - item["_is_liked"] = is_liked - model_item = StremioLibraryItem(**item) - else: - model_item = item - - return self._calculate_score_components(model_item)["final_score"] - def _calculate_score_components(self, item: StremioLibraryItem) -> dict: """Internal logic to calculate score components.""" state = item.state @@ -160,9 +138,6 @@ def _calculate_score_components(self, item: StremioLibraryItem) -> dict: added_to_library_score = 0.0 if not item.temp and not item.removed: added_to_library_score = 100.0 - # if item.removed: - # # should we penalize for removed items? 
- # added_to_library_score = -50.0 # Calculate Final Score final_score = ( diff --git a/app/services/profile/service.py b/app/services/profile/service.py new file mode 100644 index 0000000..6f92e85 --- /dev/null +++ b/app/services/profile/service.py @@ -0,0 +1,528 @@ +from typing import Any + +from loguru import logger + +from app.core.settings import UserSettings +from app.models.history import WatchHistory, WatchHistoryItem +from app.models.library import LibraryCollection, StremioLibraryItem, StremioState +from app.models.profile import ScoredItem, TasteProfile +from app.services.profile.builder import ProfileBuilder +from app.services.profile.sampling import sample_items +from app.services.profile.scoring import ScoringService +from app.services.profile.vectorizer import ItemVectorizer +from app.services.recommendation.filtering import RecommendationFiltering +from app.services.stremio.library import stremio_library_to_watch_history +from app.services.tmdb.service import get_tmdb_service +from app.services.user_cache import user_cache + + +def _watch_history_item_to_scored(item: WatchHistoryItem) -> ScoredItem: + """Convert a WatchHistoryItem to a ScoredItem for the existing vectorizer pipeline.""" + state_kwargs: dict[str, Any] = {} + if item.last_watched: + state_kwargs["lastWatched"] = item.last_watched + state_kwargs["timesWatched"] = item.watch_count + + if item.completion < 1.0: + state_kwargs["duration"] = 6000 + state_kwargs["timeWatched"] = int(6000 * item.completion) + else: + state_kwargs["timesWatched"] = max(item.watch_count, 1) + state_kwargs["flaggedWatched"] = 1 + + state = StremioState(**state_kwargs) + + is_loved = item.rating is not None and item.rating >= 9.0 + is_liked = item.rating is not None and 7.0 <= item.rating < 9.0 + + lib_item = StremioLibraryItem( + _id=item.imdb_id, + type=item.type, + name=item.name, + state=state, + temp=False, + removed=False, + _is_loved=is_loved, + _is_liked=is_liked, + ) + + source_type = "loved" if 
is_loved else ("liked" if is_liked else "watched") + + return ScoredItem( + item=lib_item, + score=50.0, + completion_rate=item.completion, + is_rewatched=item.watch_count > 1, + is_recent=False, + source_type=source_type, + ) + + +class ProfileService: + """Builds, updates, caches, and exposes user taste profiles.""" + + def __init__(self, language: str = "en-US", tmdb_api_key: str | None = None): + self.scoring_service = ScoringService() + tmdb_service = get_tmdb_service(language=language, api_key=tmdb_api_key) + vectorizer = ItemVectorizer(tmdb_service) + self.builder = ProfileBuilder(vectorizer) + + async def build_profile_from_library( + self, + library_items: LibraryCollection, + content_type: str, + stremio_service: Any = None, + auth_key: str | None = None, + ) -> tuple[TasteProfile | None, set[int], set[str]]: + """Build taste profile from library items and get watched sets.""" + watched_imdb, watched_tmdb = await RecommendationFiltering.get_exclusion_sets( + stremio_service, library_items, auth_key + ) + + typed = library_items.for_type(content_type) + if typed.is_empty(): + return None, watched_tmdb, watched_imdb + + sampled = sample_items(typed, content_type, self.scoring_service) + profile = await self.builder.build_profile(sampled, content_type=content_type) + if profile is not None: + profile.source = "stremio" + return profile, watched_tmdb, watched_imdb + + async def build_profile_incremental( + self, + library_items: LibraryCollection, + content_type: str, + token: str, + stremio_service: Any = None, + auth_key: str | None = None, + ) -> tuple[TasteProfile | None, set[int], set[str]]: + """Build profile incrementally if possible, fallback to full rebuild.""" + watched_imdb, watched_tmdb = await RecommendationFiltering.get_exclusion_sets( + stremio_service, library_items, auth_key + ) + + typed = library_items.for_type(content_type) + typed_items = typed.all_items() + + if not typed_items: + return None, watched_tmdb, watched_imdb + + try: + 
library_changed = await user_cache.has_library_changed(token, content_type, typed_items) + + if not library_changed: + existing_profile = await user_cache.get_profile(token, content_type) + if existing_profile: + return existing_profile, watched_tmdb, watched_imdb + + existing_profile = await user_cache.get_profile(token, content_type) + + if existing_profile: + processed_ids = existing_profile.processed_items + current_ids = {it.id for it in typed_items} + is_legacy = not processed_ids and (existing_profile.genre_scores or existing_profile.director_scores) + + if not processed_ids.issubset(current_ids) or is_legacy: + reason = "Legacy profile detected" if is_legacy else "Items removed from library" + logger.debug(f"[{token[:8]}...] {reason}, falling back to full rebuild") + else: + new_item_ids = current_ids - processed_ids + + if not new_item_ids: + return existing_profile, watched_tmdb, watched_imdb + + logger.debug(f"[{token[:8]}...] Found {len(new_item_ids)} new items, using incremental update") + + def _is_new(it) -> bool: + item_id = it.id if hasattr(it, "id") else (it.get("_id") or it.get("id")) + return item_id in new_item_ids + + new_library = LibraryCollection( + loved=[it for it in typed.loved if _is_new(it)], + liked=[it for it in typed.liked if _is_new(it)], + watched=[it for it in typed.watched if _is_new(it)], + added=[it for it in typed.added if _is_new(it)], + ) + + sampled = sample_items(new_library, content_type, self.scoring_service) + + if not sampled: + return existing_profile, watched_tmdb, watched_imdb + + updated_profile = await self.builder.update_profile_incrementally( + existing_profile, sampled, content_type=content_type + ) + + await user_cache.update_library_hash(token, content_type, typed_items) + return updated_profile, watched_tmdb, watched_imdb + + except Exception as e: + logger.warning(f"[{token[:8]}...] Incremental update failed, falling back to full rebuild: {e}") + + logger.debug(f"[{token[:8]}...] 
Using full rebuild") + profile, _, _ = await self.build_profile_from_library(library_items, content_type, stremio_service, auth_key) + await user_cache.update_library_hash(token, content_type, typed_items) + return profile, watched_tmdb, watched_imdb + + async def build_profile_from_watch_history( + self, + watch_history: WatchHistory, + content_type: str, + extra_exclusion_imdb: set[str] | None = None, + source: str | None = None, + ) -> tuple[TasteProfile | None, set[str]]: + """Build taste profile from external watch history (Trakt/Simkl).""" + typed_items = [it for it in watch_history.items if it.type == content_type] + if not typed_items: + return None, extra_exclusion_imdb or set() + + scored_items = [_watch_history_item_to_scored(it) for it in typed_items] + profile = await self.builder.build_profile(scored_items, content_type=content_type) + + if profile is not None: + profile.source = source or watch_history.source or "stremio" + + watched_imdb = watch_history.imdb_ids() + if extra_exclusion_imdb: + watched_imdb |= extra_exclusion_imdb + + return profile, watched_imdb + + async def build_and_cache_profile( + self, + token: str, + content_type: str, + library_items: LibraryCollection, + stremio_service: Any = None, + auth_key: str | None = None, + user_settings: UserSettings | None = None, + ) -> tuple[TasteProfile | None, set[int], set[str]]: + """Build profile data and cache the profile and watched sets. + + Dispatches on user_settings.watch_history_source: uses Trakt or Simkl + when the user connected those, otherwise the Stremio library. + """ + source = user_settings.watch_history_source if user_settings else "stremio" + + # Drop a cached profile that was built from a different source than the + # one the user has currently selected — otherwise switching sources in + # the configure page silently keeps serving the old (wrong) profile. 
+ cached = await user_cache.get_profile(token, content_type) + if cached and getattr(cached, "source", "stremio") != source: + logger.info( + f"[{token[:8]}...] Cached profile source '{cached.source}' " + f"!= requested '{source}'; invalidating before rebuild." + ) + await user_cache.invalidate_profile(token, content_type) + await user_cache.invalidate_watched_sets(token, content_type) + + if source in ("trakt", "simkl"): + profile, watched_tmdb, watched_imdb = await self._build_from_external_source( + source, user_settings, content_type, library_items, token=token + ) + else: + profile, watched_tmdb, watched_imdb = await self.build_profile_incremental( + library_items, + content_type, + token, + stremio_service, + auth_key, + ) + + await user_cache.set_profile_and_watched_sets(token, content_type, profile, watched_tmdb, watched_imdb) + return profile, watched_tmdb, watched_imdb + + async def fetch_external_watch_history( + self, + source: str, + user_settings: UserSettings | None, + token: str | None = None, + ) -> tuple[WatchHistory | None, bool, bool]: + """Fetch watch history from Trakt or Simkl with token refresh + revoke handling. + + Returns (history, token_missing, token_revoked). On any non-auth failure + history is None and both flags are False — caller decides whether to + fall back. token_revoked=True implies the stored credential has been + cleared from the user record by `_clear_revoked_token`. + """ + import httpx + + watch_history: WatchHistory | None = None + token_missing = False + token_revoked = False + + if source == "trakt": + if user_settings and user_settings.trakt_access_token: + # Refresh proactively when within 7 days of expiry, then fetch. + # On a 401 from get_history, attempt one reactive refresh + retry + # before giving up — covers cases where expires_at was missing + # or the server clock skewed past it. 
+ access_token, _ = await self._ensure_trakt_token_fresh(token, user_settings) + + from app.services.trakt import trakt_service + + try: + watch_history = await trakt_service.get_history(access_token) + except httpx.HTTPStatusError as e: + if e.response.status_code == 401 and user_settings.trakt_refresh_token and token: + logger.info(f"[{token[:8]}...] Trakt 401 on get_history; attempting reactive refresh.") + refreshed = await self._refresh_trakt_token(token, user_settings.trakt_refresh_token) + if refreshed: + try: + watch_history = await trakt_service.get_history(refreshed) + except httpx.HTTPStatusError as retry_e: + if retry_e.response.status_code in (401, 403): + token_revoked = True + logger.error( + f"Trakt token still rejected after refresh (HTTP " + f"{retry_e.response.status_code}). Clearing stored token." + ) + else: + logger.error( + f"Trakt history fetch failed after refresh (HTTP " + f"{retry_e.response.status_code}: {retry_e})." + ) + watch_history = None + else: + token_revoked = True + logger.error("Trakt refresh failed; clearing stored token. User must reconnect Trakt.") + elif e.response.status_code in (401, 403): + token_revoked = True + logger.error( + f"Trakt token rejected (HTTP {e.response.status_code}). " + "Clearing stored token; user must reconnect Trakt." + ) + else: + logger.error( + f"Trakt history fetch failed (HTTP {e.response.status_code}: {e}). " + "Falling back to Stremio library." + ) + watch_history = None + except Exception as e: + logger.error( + f"Trakt history fetch failed ({type(e).__name__}: {e}). Falling back to Stremio library." 
+ ) + watch_history = None + else: + token_missing = True + elif source == "simkl": + if user_settings and user_settings.simkl_access_token: + from app.core.config import settings as app_settings + from app.services.simkl import simkl_service + + try: + watch_history = await simkl_service.get_history( + user_settings.simkl_access_token, + app_settings.SIMKL_CLIENT_ID or "", + ) + except httpx.HTTPStatusError as e: + if e.response.status_code in (401, 403): + token_revoked = True + logger.error( + f"Simkl token rejected (HTTP {e.response.status_code}). " + "Clearing stored token; user must reconnect Simkl." + ) + else: + logger.error( + f"Simkl history fetch failed (HTTP {e.response.status_code}: {e}). " + "Falling back to Stremio library." + ) + watch_history = None + except Exception as e: + logger.error( + f"Simkl history fetch failed ({type(e).__name__}: {e}). Falling back to Stremio library." + ) + watch_history = None + else: + token_missing = True + + if token_missing: + logger.error( + f"watch_history_source='{source}' but no {source}_access_token in user settings. " + "Falling back to Stremio library — the user likely needs to redo OAuth." + ) + + if token_revoked and token: + await self._clear_revoked_token(token, source) + + return watch_history, token_missing, token_revoked + + async def fetch_external_library( + self, + source: str, + user_settings: UserSettings | None, + token: str | None = None, + ) -> LibraryCollection | None: + """Fetch external watch history and return it as a LibraryCollection. + + Returns None when no history could be fetched (missing/revoked token, + network failure). The collection's `source` field carries the origin + so cache layers can detect a source switch. 
+ """ + from app.services.stremio.library import watch_history_to_library_collection + + history, _, _ = await self.fetch_external_watch_history(source, user_settings, token) + if history is None: + return None + return watch_history_to_library_collection(history) + + async def _build_from_external_source( + self, + source: str, + user_settings: UserSettings | None, + content_type: str, + library: LibraryCollection, + token: str | None = None, + ) -> tuple[TasteProfile | None, set[int], set[str]]: + """Build a profile from an external history source, falling back to the + Stremio library when the external fetch fails or no token is set. + + When the passed-in library was already built from the same external + source (load_user_context handles that), we avoid the duplicate fetch + and read history straight off the library. + """ + watch_history: WatchHistory | None = None + + if library.source == source: + # context layer already pulled from the same source — reuse it. + watch_history = WatchHistory( + items=[ + item + for items in (library.loved, library.liked, library.watched) + for item in ( + WatchHistoryItem( + imdb_id=lib.id, + type=lib.type, + name=lib.name, + rating=(9.0 if lib.is_loved else (7.0 if lib.is_liked else None)), + watch_count=lib.state.timesWatched or (1 if lib.state.flaggedWatched else 0), + completion=( + 1.0 + if lib.state.flaggedWatched or (lib.state.timesWatched or 0) > 0 + else ( + min((lib.state.timeWatched or 0) / lib.state.duration, 1.0) + if lib.state.duration + else 0.0 + ) + ), + last_watched=lib.state.lastWatched, + source=source, + ) + for lib in items + ) + ], + source=source, + ) + else: + watch_history, _, _ = await self.fetch_external_watch_history(source, user_settings, token) + + # An empty WatchHistory still counts as "the source spoke" — only fall back + # on actual failure (None), not on a user with zero history. 
+ effective_source = source + if watch_history is None: + watch_history = stremio_library_to_watch_history(library) + effective_source = "stremio" + + stremio_imdb = library.all_imdb_ids() + profile, watched_imdb = await self.build_profile_from_watch_history( + watch_history, content_type, extra_exclusion_imdb=stremio_imdb, source=effective_source + ) + return profile, set(), watched_imdb + + async def _ensure_trakt_token_fresh(self, token: str | None, user_settings: UserSettings) -> tuple[str, bool]: + """If the Trakt access token is within 7 days of expiry, refresh it. + + Returns (access_token_to_use, was_refreshed). Always returns the best + token we have — even if refresh fails the original is returned so the + caller can still attempt the request and surface the real failure. + """ + import time as _time + + access_token = user_settings.trakt_access_token or "" + expires_at = user_settings.trakt_token_expires_at or 0 + if not (token and user_settings.trakt_refresh_token and expires_at): + return access_token, False + + seven_days = 7 * 24 * 60 * 60 + if _time.time() < expires_at - seven_days: + return access_token, False + + logger.info(f"[{token[:8]}...] Trakt token within refresh window; refreshing proactively.") + refreshed = await self._refresh_trakt_token(token, user_settings.trakt_refresh_token) + if refreshed: + return refreshed, True + return access_token, False + + async def _refresh_trakt_token(self, token: str, refresh_token: str) -> str | None: + """Refresh a Trakt access token and persist the new tokens. + + Returns the new access token on success, None on failure. 
+ """ + import time as _time + + from app.core.config import settings as app_settings + from app.services.token_store import token_store + from app.services.trakt import trakt_service + + redirect_uri = f"{app_settings.HOST_NAME}/auth/trakt/callback" + try: + data = await trakt_service.refresh_token(refresh_token, redirect_uri) + except Exception as e: + logger.warning(f"[{token[:8]}...] Trakt refresh_token call failed: {e}") + return None + + new_access = data.get("access_token") or "" + new_refresh = data.get("refresh_token") or refresh_token + expires_in = int(data.get("expires_in") or 0) + created_at = int(data.get("created_at") or _time.time()) + new_expires_at = created_at + expires_in if expires_in else 0 + if not new_access: + logger.warning(f"[{token[:8]}...] Trakt refresh returned no access_token.") + return None + + try: + credentials = await token_store.get_user_data(token) + if credentials: + settings_dict = credentials.get("settings") or {} + settings_dict["trakt_access_token"] = new_access + settings_dict["trakt_refresh_token"] = new_refresh + settings_dict["trakt_token_expires_at"] = new_expires_at + credentials["settings"] = settings_dict + await token_store.update_user_data(token, credentials) + logger.info(f"[{token[:8]}...] Trakt token refreshed; new expiry={new_expires_at}.") + except Exception as e: + logger.warning(f"[{token[:8]}...] Failed to persist refreshed Trakt token: {e}") + + return new_access + + async def _clear_revoked_token(self, token: str, source: str) -> None: + """Wipe a revoked external-source token from stored credentials. + + Called when Trakt/Simkl returns 401/403 — keeps the user from looping + on a dead token forever. Their /configure page will show the source + as disconnected on next visit so they can reconnect. 
+ """ + from app.services.token_store import token_store + + try: + credentials = await token_store.get_user_data(token) + if not credentials: + return + settings_dict = credentials.get("settings") or {} + mutated = False + if source == "trakt": + for field in ("trakt_access_token", "trakt_refresh_token"): + if settings_dict.get(field): + settings_dict[field] = None + mutated = True + elif source == "simkl": + if settings_dict.get("simkl_access_token"): + settings_dict["simkl_access_token"] = None + mutated = True + if mutated: + credentials["settings"] = settings_dict + await token_store.update_user_data(token, credentials) + logger.info(f"[{token[:8]}...] Cleared revoked {source} credentials.") + except Exception as e: + logger.warning(f"[{token[:8]}...] Failed to clear revoked {source} token: {e}") diff --git a/app/services/profile/vectorizer.py b/app/services/profile/vectorizer.py index 83fd541..8c9fd5b 100644 --- a/app/services/profile/vectorizer.py +++ b/app/services/profile/vectorizer.py @@ -2,7 +2,7 @@ import httpx -from app.models.scoring import ScoredItem +from app.models.profile import ScoredItem from app.services.cinemeta_service import CinemetaService, cinemeta_service from app.services.profile.constants import ( CAST_POSITION_LEAD, @@ -49,11 +49,14 @@ def vectorize_item(metadata: dict[str, Any]) -> dict[str, Any] | None: keywords = [k.get("id") for k in keywords if k.get("id")] - # Extract cast (top 10) + # Top 3 cast only — main + two critical supporting. Tracking deeper into + # the credit list pollutes "favorite cast" with bit-part actors who + # happen to appear across many genre films but who the user wasn't + # actually drawn to. CreatorsService pairs this with a freq>=2 filter. 
cast = [] credits = metadata.get("credits", {}) or {} cast_list = credits.get("cast", []) or [] - for idx, actor in enumerate(cast_list[:10]): + for idx, actor in enumerate(cast_list[:3]): actor_id = actor.get("id") if isinstance(actor, dict) else actor if actor_id: cast.append(actor_id) @@ -205,7 +208,7 @@ def _extract_cast_with_positions(self, cast: list[Any]) -> list[dict[str, Any]]: return [] result = [] - for idx, cast_item in enumerate(cast[:10]): # Top 10 only + for idx, cast_item in enumerate(cast[:3]): # Top 3 — leads only if isinstance(cast_item, dict): cast_id = cast_item.get("id") position = cast_item.get("position", idx) @@ -289,7 +292,10 @@ async def _extract_runtime_bucket(self, cinemeta_metadata: dict[str, Any]) -> st runtime_str = cinemeta_metadata.get("runtime", "0 min") if runtime_str: - runtime = int(runtime_str.split(" ")[0]) + try: + runtime = int(str(runtime_str).split(" ")[0]) + except (ValueError, TypeError): + runtime = 0 if not runtime or not isinstance(runtime, (int, float)): return None diff --git a/app/services/recommendation/all_based.py b/app/services/recommendation/all_based.py index bb6910b..9bba6b2 100644 --- a/app/services/recommendation/all_based.py +++ b/app/services/recommendation/all_based.py @@ -4,18 +4,18 @@ from loguru import logger from app.core.settings import UserSettings -from app.models.taste_profile import TasteProfile +from app.models.library import LibraryCollection +from app.models.profile import TasteProfile from app.services.profile.scorer import ProfileScorer -from app.services.recommendation.filtering import RecommendationFiltering -from app.services.recommendation.metadata import RecommendationMetadata -from app.services.recommendation.scoring import RecommendationScoring -from app.services.recommendation.utils import ( - content_type_to_mtype, +from app.services.recommendation.filtering import ( + RecommendationFiltering, filter_by_genres, filter_items_by_settings, filter_watched_by_imdb, - resolve_tmdb_id, 
) +from app.services.recommendation.metadata import RecommendationMetadata +from app.services.recommendation.scoring import RecommendationScoring +from app.services.recommendation.utils import content_type_to_mtype, resolve_tmdb_id from app.services.simkl import simkl_service from app.services.tmdb.service import TMDBService @@ -34,13 +34,12 @@ def __init__(self, tmdb_service: TMDBService, user_settings: UserSettings | None async def get_recommendations_from_all_items( self, - library_items: dict[str, list[dict[str, Any]]], + library_items: LibraryCollection, content_type: str, watched_tmdb: set[int], watched_imdb: set[str], - whitelist: set[int] | None = None, limit: int = 20, - item_type: str = "loved", # "loved" or "liked" + item_type: str = "loved", profile: TasteProfile | None = None, ) -> list[dict[str, Any]]: """ @@ -58,7 +57,6 @@ async def get_recommendations_from_all_items( content_type: Content type (movie/series) watched_tmdb: Set of watched TMDB IDs watched_imdb: Set of watched IMDB IDs - whitelist: Genre whitelist limit: Number of items to return item_type: "loved" or "liked" profile: Optional profile for scoring (if None, uses popularity only) @@ -66,10 +64,9 @@ async def get_recommendations_from_all_items( Returns: List of recommended items """ - # Get all loved or liked items for the content type - items = library_items.get(item_type, []) + items = getattr(library_items, item_type, []) - typed_items = [it for it in items if it.get("type") == content_type] + typed_items = [it for it in items if it.type == content_type] logger.info(f"Typed items: {len(typed_items)}") @@ -111,7 +108,7 @@ async def get_recommendations_from_all_items( logger.info(f"Fetching TMDB recommendations for {len(top_items)} top items") for item in top_items: - item_id = item.get("_id", "") + item_id = item.id if not item_id: continue tasks.append(self._fetch_recommendations_for_item(item_id, mtype)) @@ -139,15 +136,13 @@ async def get_recommendations_from_all_items( # Filter by 
genres and watched items excluded_ids = RecommendationFiltering.get_excluded_genre_ids(self.user_settings, content_type) - whitelist = whitelist or set() - filtered = filter_by_genres(candidates, watched_tmdb, whitelist, excluded_ids) + filtered = filter_by_genres(candidates, watched_tmdb, excluded_ids) logger.info(f"Filtered {len(filtered)} candidates") # Score with profile if available scored = [] if profile: - rotation_seed = RecommendationScoring.generate_rotation_seed() # Daily rotation for fresh recommendations for item in filtered: try: final_score = RecommendationScoring.calculate_final_score( @@ -155,13 +150,8 @@ async def get_recommendations_from_all_items( profile=profile, scorer=self.scorer, mtype=mtype, - rotation_seed=rotation_seed, ) - # Apply genre multiplier (if whitelist available) - genre_mult = RecommendationFiltering.get_genre_multiplier(item.get("genre_ids"), whitelist) - final_score *= genre_mult - scored.append((final_score, item)) except Exception as e: logger.debug(f"Failed to score item {item.get('id')}: {e}") @@ -208,7 +198,7 @@ async def _fetch_simkl_candidates(self, top_items: list[dict[str, Any]], mtype: # Extract IMDB IDs imdb_ids = [] for item in top_items: - item_id = item.get("_id", "") + item_id = item.id if item_id and item_id.startswith("tt"): imdb_ids.append(item_id) diff --git a/app/services/recommendation/candidate_sources.py b/app/services/recommendation/candidate_sources.py new file mode 100644 index 0000000..484f6a5 --- /dev/null +++ b/app/services/recommendation/candidate_sources.py @@ -0,0 +1,292 @@ +import asyncio +from datetime import date +from typing import Any + +from loguru import logger + +from app.core.settings import UserSettings +from app.models.library import LibraryCollection +from app.models.profile import TasteProfile +from app.services.profile.sampling import sample_items +from app.services.profile.scoring import ScoringService +from app.services.recommendation.filtering import ( + 
RecommendationFiltering, + apply_discover_filters, + filter_items_by_settings, +) +from app.services.recommendation.utils import resolve_tmdb_id +from app.services.simkl import simkl_service +from app.services.tmdb.service import TMDBService + + +def _era_to_year_start(era: str) -> int | None: + """Convert era bucket to starting year.""" + era_map = { + "pre-1970s": 1950, + "1970s": 1970, + "1980s": 1980, + "1990s": 1990, + "2000s": 2000, + "2010s": 2010, + "2020s": 2020, + } + return era_map.get(era) + + +class CandidateFetcher: + """Fetches recommendation candidates from multiple sources (TMDB, Simkl, Discover).""" + + def __init__( + self, + tmdb_service: TMDBService, + user_settings: UserSettings | None = None, + scoring_service: ScoringService | None = None, + ): + self.tmdb_service = tmdb_service + self.user_settings = user_settings + self.scoring_service = scoring_service or ScoringService() + + async def fetch_recommendations_from_top_items( + self, + library_items: LibraryCollection, + content_type: str, + mtype: str, + ) -> list[dict[str, Any]]: + """Fetch recommendations from top items (loved/watched/liked/added).""" + top_items = sample_items(library_items, content_type, self.scoring_service, max_items=15) + + candidates = [] + tasks = [] + + for item in top_items: + item = item.item + item_id = item.id + if not item_id: + continue + + tmdb_id = await resolve_tmdb_id(item_id, self.tmdb_service) + if not tmdb_id: + continue + + tasks.append(self.tmdb_service.get_recommendations(tmdb_id, mtype, page=1)) + + logger.info(f"Fetching recommendations from {len(tasks)} top library items") + results = await asyncio.gather(*tasks, return_exceptions=True) + + failed_count = 0 + for res in results: + if isinstance(res, Exception): + failed_count += 1 + logger.debug(f"Recommendation fetch failed: {res}") + continue + candidates.extend(res.get("results", [])) + + if failed_count > 0: + logger.info(f"{failed_count}/{len(tasks)} recommendation fetches failed (expected 
for items with no recs)") + logger.debug(f"Fetched {len(candidates)} candidates from top items") + + return candidates + + async def fetch_simkl_recommendations( + self, + library_items: LibraryCollection, + content_type: str, + mtype: str, + ) -> list[dict[str, Any]]: + """Fetch recommendations from Simkl for top library items.""" + simkl_api_key = self.user_settings.simkl_api_key if self.user_settings else None + if not simkl_api_key: + logger.warning("Simkl API key not found, skipping Simkl recommendations") + return [] + + top_items = sample_items(library_items, content_type, self.scoring_service, max_items=15) + + imdb_ids = [] + for scored_item in top_items: + item_id = scored_item.item.id + if item_id and item_id.startswith("tt"): + imdb_ids.append(item_id) + + if not imdb_ids: + logger.warning("No valid IMDB IDs found for Simkl recommendations") + return [] + + logger.info(f"Fetching Simkl recommendations for {len(imdb_ids)} items") + + year_min = getattr(self.user_settings, "year_min", None) + year_max = getattr(self.user_settings, "year_max", None) + + try: + candidates = await simkl_service.get_recommendations_batch( + imdb_ids, + mtype, + simkl_api_key, + max_per_item=8, + year_min=year_min, + year_max=year_max, + ) + except Exception as e: + logger.error(f"Error fetching Simkl recommendations: {e}") + return [] + + logger.info(f"Fetched {len(candidates)} candidates from Simkl") + return candidates + + def _add_discover_task(self, tasks: list, mtype: str, without_genres: str | None, **kwargs: Any) -> None: + """Add a discover task to the list of tasks with default parameters.""" + sort_by = RecommendationFiltering.get_sort_by_preference(self.user_settings) + params = { + "sort_by": sort_by, + **kwargs, + } + if without_genres: + params["without_genres"] = without_genres + + params = apply_discover_filters(params, self.user_settings) + tasks.append(self.tmdb_service.get_discover(mtype, **params)) + + async def fetch_discover_with_profile( + self, 
profile: TasteProfile, content_type: str, mtype: str + ) -> list[dict[str, Any]]: + """Fetch discover results using profile features.""" + excluded_genre_ids = RecommendationFiltering.get_excluded_genre_ids(self.user_settings, content_type) + without_genres = "|".join(str(g) for g in excluded_genre_ids) if excluded_genre_ids else None + + logger.debug(f"Excluded genres for {content_type}: {excluded_genre_ids}") + + top_genres = profile.get_top_genres(limit=5) + top_keywords = profile.get_top_keywords(limit=5) + top_directors = profile.get_top_directors(limit=3) + top_cast = profile.get_top_cast(limit=5) + top_eras = profile.get_top_eras(limit=2) + top_countries = profile.get_top_countries(limit=5) + + candidates = [] + tasks = [] + + if top_genres: + genre_ids = [g[0] for g in top_genres] + self._add_discover_task( + tasks, + mtype, + without_genres, + with_genres="|".join(str(g) for g in genre_ids), + page=1, + ) + + if top_keywords: + keyword_ids = [k[0] for k in top_keywords] + for page in range(1, 3): + self._add_discover_task( + tasks, + mtype, + without_genres, + with_keywords="|".join(str(k) for k in keyword_ids), + page=page, + ) + + if top_directors: + director_ids = [d[0] for d in top_directors] + self._add_discover_task( + tasks, + mtype, + without_genres, + with_crew="|".join(str(d) for d in director_ids), + page=1, + ) + + if top_cast: + cast_ids = [c[0] for c in top_cast] + self._add_discover_task( + tasks, + mtype, + without_genres, + with_cast="|".join(str(c) for c in cast_ids), + page=1, + ) + + if top_eras: + era = top_eras[0][0] + year_start = _era_to_year_start(era) + if year_start: + prefix = "first_air_date" if mtype == "tv" else "primary_release_date" + lte_prefix = ( + date.today().isoformat() if year_start + 9 > date.today().year else f"{year_start + 9}-12-31" + ) + params = { + f"{prefix}.gte": f"{year_start}-01-01", + f"{prefix}.lte": lte_prefix, + "page": 1, + } + self._add_discover_task(tasks, mtype, without_genres, **params) + + if 
top_countries: + country_codes = [c[0] for c in top_countries] + params = { + "with_origin_country": "|".join(country_codes), + "page": 1, + } + self._add_discover_task(tasks, mtype, without_genres, **params) + + logger.debug(f"Fetching {len(tasks)} discover queries with profile features") + results = await asyncio.gather(*tasks, return_exceptions=True) + + failed_count = 0 + for res in results: + if isinstance(res, Exception): + failed_count += 1 + logger.warning(f"Discover query failed: {res}") + continue + candidates.extend(res.get("results", [])) + + if failed_count > 0: + logger.warning(f"{failed_count}/{len(tasks)} discover queries failed") + logger.debug(f"Fetched {len(candidates)} candidates from discover") + + return candidates + + async def fetch_trending_and_popular(self, content_type: str, mtype: str) -> list[dict[str, Any]]: + """Fetch trending and popular items (for recent items injection).""" + candidates = [] + try: + trending = await self.tmdb_service.get_trending(mtype, time_window="week", page=1) + candidates.extend(trending.get("results", [])) + except Exception as e: + logger.debug(f"Failed to fetch trending: {e}") + + return candidates + + async def fetch_all_candidates( + self, + profile: TasteProfile, + library_items: LibraryCollection, + content_type: str, + mtype: str, + ) -> dict[int, dict[str, Any]]: + """Fetch and merge candidates from all sources, deduped by TMDB ID.""" + all_candidates: dict[int, dict[str, Any]] = {} + + # 1. 
Fetch recommendations from top items + simkl_api_key = self.user_settings.simkl_api_key if self.user_settings else None + if simkl_api_key: + rec_candidates = await self.fetch_simkl_recommendations(library_items, content_type, mtype) + if not rec_candidates: + logger.info("Simkl returned no results, falling back to TMDB") + rec_candidates = await self.fetch_recommendations_from_top_items(library_items, content_type, mtype) + rec_candidates = filter_items_by_settings(rec_candidates, self.user_settings, simkl=True) + else: + rec_candidates = await self.fetch_recommendations_from_top_items(library_items, content_type, mtype) + rec_candidates = filter_items_by_settings(rec_candidates, self.user_settings) + + for item in rec_candidates: + if item.get("id"): + all_candidates[item["id"]] = item + + # 2. Fetch discover with profile features + discover_candidates = await self.fetch_discover_with_profile(profile, content_type, mtype) + discover_candidates = filter_items_by_settings(discover_candidates, self.user_settings) + for item in discover_candidates: + if item.get("id"): + all_candidates[item["id"]] = item + + return all_candidates diff --git a/app/services/recommendation/catalog_service.py b/app/services/recommendation/catalog_service.py index e277765..4e7eddc 100644 --- a/app/services/recommendation/catalog_service.py +++ b/app/services/recommendation/catalog_service.py @@ -1,4 +1,3 @@ -import random import re import time from typing import Any @@ -7,114 +6,44 @@ from loguru import logger from app.core.config import settings -from app.core.constants import DEFAULT_CATALOG_LIMIT, DEFAULT_MIN_ITEMS +from app.core.constants import DEFAULT_CATALOG_LIMIT from app.core.security import redact_token -from app.core.settings import UserSettings, get_default_settings, resolve_tmdb_api_key -from app.models.taste_profile import TasteProfile +from app.core.settings import UserSettings, resolve_tmdb_api_key +from app.models.library import LibraryCollection +from app.models.profile 
import TasteProfile from app.services.catalog_updater import catalog_updater -from app.services.profile.integration import ProfileIntegration +from app.services.context import UserContext, extract_settings, load_user_context +from app.services.profile.service import ProfileService from app.services.recommendation.all_based import AllBasedService +from app.services.recommendation.catalog_utils import clean_meta, shuffle_data_if_needed from app.services.recommendation.creators import CreatorsService from app.services.recommendation.item_based import ItemBasedService from app.services.recommendation.theme_based import ThemeBasedService from app.services.recommendation.top_picks import TopPicksService -from app.services.recommendation.utils import pad_to_min -from app.services.stremio.service import StremioBundle from app.services.tmdb.service import get_tmdb_service from app.services.token_store import token_store from app.services.user_cache import user_cache -from app.utils.catalog import cache_profile_and_watched_sets - - -def should_shuffle(user_settings: UserSettings, catalog_id: str) -> bool: - config = next((c for c in user_settings.catalogs if c.id == catalog_id), None) - return getattr(config, "shuffle", False) if config else False - - -def shuffle_data_if_needed( - user_settings: UserSettings, catalog_id: str, data: list[dict[str, Any]] -) -> list[dict[str, Any]]: - if should_shuffle(user_settings, catalog_id): - random.shuffle(data) - return data - - -def _clean_meta(meta: dict) -> dict | None: - """Return a sanitized Stremio meta object without internal fields. - - Keeps only public keys and drops internal scoring/IDs/keywords/cast, etc. 
- """ - allowed = { - "id", - "type", - "name", - "poster", - "logo", - "background", - "description", - "releaseInfo", - "imdbRating", - "genres", - "runtime", - } - cleaned = {k: v for k, v in meta.items() if k in allowed} - # Drop empty values - cleaned = {k: v for k, v in cleaned.items() if v not in (None, "", [], {}, ())} - - # Normalize IMDb rating to a string with 1 decimal place - rating = cleaned.get("imdbRating") - if rating not in (None, ""): - try: - cleaned["imdbRating"] = f"{float(rating):.1f}" - except (TypeError, ValueError): - # Keep original value if it cannot be parsed - pass - - imdb_id = cleaned.get("id", "") - # if id does not start with tt, return None - if not imdb_id.startswith("tt"): - return None - # Use Metahub logo only when no language-aware logo was set (e.g. from TMDB) - if not cleaned.get("logo"): - cleaned["logo"] = f"https://live.metahub.space/logo/medium/{imdb_id}/img" - return cleaned class CatalogService: - def __init__(self): - pass - async def get_catalog( self, token: str, content_type: str, catalog_id: str ) -> tuple[dict[str, Any], dict[str, Any]]: - """ - Get catalog recommendations. - - Args: - token: User token - content_type: Content type (movie/series) - catalog_id: Catalog ID (watchly.rec, watchly.creators, watchly.theme.*, etc.) - - Returns: - Tuple of (recommendations dict, response headers dict) - """ - # Validate inputs + """Get catalog recommendations.""" self._validate_inputs(token, content_type, catalog_id) - # Prepare response headers - headers: dict[str, Any] = { "Access-Control-Allow-Origin": "*", "Access-Control-Allow-Headers": "*", "Content-Type": "application/json", "Cache-Control": ( - f"public, max-age={settings.CATALOG_CACHE_TTL}," "stale-while-revalidate=3600, stale-if-error=1800" + f"private, max-age={settings.CATALOG_CACHE_TTL},stale-while-revalidate=3600, stale-if-error=1800" ), } - logger.info(f"[{redact_token(token)}...] 
Fetching catalog for {content_type} with id {catalog_id}") + logger.info(f"[{redact_token(token)}] Fetching catalog for {content_type} with id {catalog_id}") - # Get credentials + # Load credentials (needed for cache check + shuffle settings) credentials = await token_store.get_user_data(token) if not credentials: logger.error("No credentials found for token") @@ -125,95 +54,84 @@ async def get_catalog( # Trigger lazy update if needed if settings.AUTO_UPDATE_CATALOGS: - logger.info(f"[{redact_token(token)}...] Triggering auto update for token") try: await catalog_updater.trigger_update(token, credentials) except Exception as e: - logger.error(f"[{redact_token(token)}...] Failed to trigger auto update: {e}") - # continue with the request even if the auto update fails - pass + logger.error(f"[{redact_token(token)}] Failed to trigger auto update: {e}") - bundle = StremioBundle() - user_settings = None + # Check cache first — avoids auth/library/profile loading on cache hit stale_data = None + cached_result = await user_cache.get_catalog(token, content_type, catalog_id) + + if cached_result: + data, created_at = cached_result + age = int(time.time()) - created_at + + if age < settings.CATALOG_REFRESH_INTERVAL_SECONDS: + logger.debug(f"[{redact_token(token)}] Using cached catalog for {content_type}/{catalog_id}") + user_settings = extract_settings(credentials) + data["metas"] = shuffle_data_if_needed(user_settings, catalog_id, data["metas"]) + return data, headers + + stale_data = data + logger.info( + f"[{redact_token(token)}] Catalog stale (age: {age}s) for " + f"{content_type}/{catalog_id}, refreshing..." 
+ ) + else: + logger.info( + f"[{redact_token(token)}] Catalog not cached for " f"{content_type}/{catalog_id}, building from scratch" + ) + # Cache miss — load full user context + ctx = await load_user_context(token) try: - # get cached catalog - cached_result = await user_cache.get_catalog(token, content_type, catalog_id) - - if cached_result: - data, created_at = cached_result - age = int(time.time()) - created_at - - # If data is fresh enough (within refresh interval), return it - if age < settings.CATALOG_REFRESH_INTERVAL_SECONDS: - logger.debug(f"[{redact_token(token)}...] Using cached catalog for {content_type}/{catalog_id}") - # Try to extract settings from credentials for shuffling, even on cached path - user_settings = self._extract_settings(credentials) - meta_data = data["metas"] - meta_data = shuffle_data_if_needed(user_settings, catalog_id, meta_data) - data["metas"] = meta_data - return data, headers - - # If data is stale, keep it for fallback - stale_data = data - logger.info( - f"[{redact_token(token)}...] Catalog is stale (age: {age}s) for {content_type}/{catalog_id}," - "refreshing..." - ) - else: - logger.info( - f"[{redact_token(token)}...] Catalog not cached for {content_type}/{catalog_id}, building from" - " scratch" - ) - - # Resolve auth and settings - auth_key = await self._resolve_auth(bundle, credentials, token) - user_settings = self._extract_settings(credentials) - - language = user_settings.language if user_settings else "en-US" - - # Try to get cached library items first - library_items = await user_cache.get_library_items(token) + return await self._build_catalog(ctx, content_type, catalog_id, headers, stale_data) + finally: + await ctx.close() - if library_items: - logger.debug(f"[{redact_token(token)}...] Using cached library items") - else: - # Fetch library if not cached - logger.info(f"[{redact_token(token)}...] 
Library items not cached, fetching from Stremio") - library_items = await bundle.library.get_library_items(auth_key) - # Cache it for future use - await user_cache.set_library_items(token, library_items) + async def _build_catalog( + self, + ctx: UserContext, + content_type: str, + catalog_id: str, + headers: dict[str, Any], + stale_data: dict[str, Any] | None, + ) -> tuple[dict[str, Any], dict[str, Any]]: + """Build fresh catalog content using the loaded user context.""" + try: + services = self._initialize_services(ctx.user_settings) + profile_service: ProfileService = services["profile"] - services = self._initialize_services(language, user_settings) - integration_service: ProfileIntegration = services["integration"] + # Load profile (cached or build fresh) + cached_data = await user_cache.get_profile_and_watched_sets(ctx.token, content_type) - # Try to get cached profile and watched sets - cached_data = await user_cache.get_profile_and_watched_sets(token, content_type) + requested_source = ctx.user_settings.watch_history_source if ctx.user_settings else "stremio" + cached_source = getattr(cached_data[0], "source", "stremio") if cached_data and cached_data[0] else None + if cached_data and cached_source is not None and cached_source != requested_source: + logger.info( + f"[{redact_token(ctx.token)}] Cached profile source '{cached_source}' " + f"!= requested '{requested_source}'; rebuilding." + ) + cached_data = None if cached_data: - # Use cached profile and watched sets profile, watched_tmdb, watched_imdb = cached_data - logger.debug(f"[{redact_token(token)}...] Using cached profile and watched sets for {content_type}") + logger.debug(f"[{redact_token(ctx.token)}] Using cached profile for {content_type}") else: - # Build profile if not cached - logger.info(f"[{redact_token(token)}...] 
Profile not cached for {content_type}, building from library") - ( - profile, - watched_tmdb, - watched_imdb, - ) = await cache_profile_and_watched_sets( - token, + source = ctx.user_settings.watch_history_source if ctx.user_settings else "stremio" + logger.info( + f"[{redact_token(ctx.token)}] Profile not cached for {content_type}, building from {source}" + ) + profile, watched_tmdb, watched_imdb = await profile_service.build_and_cache_profile( + ctx.token, content_type, - integration_service, - library_items, - bundle, - auth_key, + ctx.library, + ctx.bundle, + ctx.auth_key, + user_settings=ctx.user_settings, ) - whitelist = await integration_service.get_genre_whitelist(profile, content_type) if profile else set() - - # Route to appropriate recommendation service recommendations = await self._get_recommendations( catalog_id=catalog_id, content_type=content_type, @@ -221,62 +139,37 @@ async def get_catalog( profile=profile, watched_tmdb=watched_tmdb, watched_imdb=watched_imdb, - whitelist=whitelist, - library_items=library_items, + library_items=ctx.library, limit=DEFAULT_CATALOG_LIMIT, - user_settings=user_settings, + user_settings=ctx.user_settings, ) - # Pad if needed to meet minimum of 8 items - # # TODO: This is risky because it can fetch too many unrelated items. 
- if recommendations and len(recommendations) < DEFAULT_MIN_ITEMS: - recommendations = await pad_to_min( - content_type, - recommendations, - DEFAULT_MIN_ITEMS, - services["tmdb"], - user_settings, - watched_tmdb, - watched_imdb, - ) - logger.info(f"Returning {len(recommendations)} items for {content_type}") - # Clean and format metadata - cleaned = [_clean_meta(m) for m in recommendations] - cleaned = [m for m in cleaned if m is not None] - - cleaned = shuffle_data_if_needed(user_settings, catalog_id, cleaned) + cleaned = [m for m in (clean_meta(m) for m in recommendations) if m is not None] + cleaned = shuffle_data_if_needed(ctx.user_settings, catalog_id, cleaned) data = {"metas": cleaned} - # if catalog data is not empty, set the cache with STALE_TTL (7 days) - # This ensures we have fallback data available if the next refresh fails if cleaned: - await user_cache.set_catalog(token, content_type, catalog_id, data, settings.CATALOG_STALE_TTL) + await user_cache.set_catalog(ctx.token, content_type, catalog_id, data, settings.CATALOG_STALE_TTL) return data, headers except Exception as e: - logger.error(f"[{redact_token(token)}...] Failed to generate catalog: {e}") + logger.error(f"[{redact_token(ctx.token)}] Failed to generate catalog: {e}") - # Fallback 1: Return Stale Data if available if stale_data: logger.warning( - f"[{redact_token(token)}...] 
Serving stale content for {content_type}/{catalog_id} due to error" + f"[{redact_token(ctx.token)}] Serving stale content for " + f"{content_type}/{catalog_id} due to error" ) - # Shuffle stale data too if needed - user_settings = user_settings or self._extract_settings(credentials) meta_data = stale_data.get("metas", []) - meta_data = shuffle_data_if_needed(user_settings, catalog_id, meta_data) + meta_data = shuffle_data_if_needed(ctx.user_settings, catalog_id, meta_data) stale_data["metas"] = meta_data return stale_data, headers - # Fallback 2: Return Empty (prevents 500 error) return {"metas": []}, headers - finally: - await bundle.close() - def _validate_inputs(self, token: str, content_type: str, catalog_id: str) -> None: if not token: raise HTTPException( @@ -288,61 +181,51 @@ def _validate_inputs(self, token: str, content_type: str, catalog_id: str) -> No logger.warning(f"Invalid type: {content_type}") raise HTTPException(status_code=400, detail="Invalid type. Use 'movie' or 'series'") - # Supported IDs supported_base = [ "watchly.rec", "watchly.creators", "watchly.all.loved", "watchly.liked.all", ] - supported_prefixes = ("watchly.theme.", "watchly.loved.", "watchly.watched.") + # watchly.loved.* / watchly.watched.* kept for legacy stored manifests + # — installed Stremio clients may still request these IDs after the + # loved/watched merge until the manifest refreshes. + supported_prefixes = ( + "watchly.theme.", + "watchly.item.", + "watchly.loved.", + "watchly.watched.", + ) if catalog_id not in supported_base and not any(catalog_id.startswith(p) for p in supported_prefixes): logger.warning(f"Invalid id: {catalog_id}") raise HTTPException( status_code=400, detail=( "Invalid id. 
Supported: 'watchly.rec', 'watchly.creators', " - "'watchly.theme.<params>', 'watchly.all.loved', 'watchly.liked.all'" + "'watchly.theme.<params>', 'watchly.item.<imdb>', " + "'watchly.all.loved', 'watchly.liked.all'" ), ) - async def _resolve_auth(self, bundle: StremioBundle, credentials: dict, token: str) -> str: - auth_key = credentials.get("authKey") - email = credentials.get("email") - password = credentials.get("password") - - # Validate existing auth key - is_valid = False - if auth_key: - try: - await bundle.auth.get_user_info(auth_key) - is_valid = True - except Exception as e: - logger.error(f"Failed to validate auth key during catalog fetch: {e}") - pass - - # Try to refresh if invalid - if not is_valid and email and password: - try: - auth_key = await bundle.auth.login(email, password) - credentials["authKey"] = auth_key - # Update token store with refreshed credentials - await token_store.update_user_data(token, credentials) - except Exception as e: - logger.error(f"Failed to refresh auth key during catalog fetch: {e}") - - if not auth_key: - logger.error("No auth key found during catalog fetch") - raise HTTPException(status_code=401, detail="Stremio session expired. 
Please reconfigure.") - - return auth_key - - def _extract_settings(self, credentials: dict) -> UserSettings: - settings_dict = credentials.get("settings", {}) - return UserSettings(**settings_dict) if settings_dict else get_default_settings() + def _initialize_services(self, user_settings: UserSettings) -> dict[str, Any]: + tmdb_key = resolve_tmdb_api_key(user_settings) + language = user_settings.language + tmdb_service = get_tmdb_service(language=language, api_key=tmdb_key) + return { + "tmdb": tmdb_service, + "profile": ProfileService(language=language, tmdb_api_key=tmdb_key), + "item": ItemBasedService(tmdb_service, user_settings), + "theme": ThemeBasedService(tmdb_service, user_settings), + "top_picks": TopPicksService(tmdb_service, user_settings), + "creators": CreatorsService(tmdb_service, user_settings), + "all_based": AllBasedService(tmdb_service, user_settings), + } async def _get_trending_fallback( - self, content_type: str, limit: int = 20, user_settings: UserSettings | None = None + self, + content_type: str, + limit: int = 20, + user_settings: UserSettings | None = None, ) -> list[dict[str, Any]]: """Get trending items for new users without profiles.""" from app.services.recommendation.utils import content_type_to_mtype @@ -353,11 +236,9 @@ async def _get_trending_fallback( tmdb_service = get_tmdb_service(language=language, api_key=tmdb_key) try: - # Fetch trending week trending = await tmdb_service.get_trending(mtype, "week") items = trending.get("results", []) - # Enrich metadata from app.services.recommendation.metadata import RecommendationMetadata return await RecommendationMetadata.fetch_batch(tmdb_service, items, content_type, user_settings=None) @@ -365,19 +246,6 @@ async def _get_trending_fallback( logger.warning(f"Failed to fetch trending items: {e}") return [] - def _initialize_services(self, language: str, user_settings: UserSettings) -> dict[str, Any]: - tmdb_key = resolve_tmdb_api_key(user_settings) - tmdb_service = 
get_tmdb_service(language=language, api_key=tmdb_key) - return { - "tmdb": tmdb_service, - "integration": ProfileIntegration(language=language, tmdb_api_key=tmdb_key), - "item": ItemBasedService(tmdb_service, user_settings), - "theme": ThemeBasedService(tmdb_service, user_settings), - "top_picks": TopPicksService(tmdb_service, user_settings), - "creators": CreatorsService(tmdb_service, user_settings), - "all_based": AllBasedService(tmdb_service, user_settings), - } - async def _get_recommendations( self, catalog_id: str, @@ -386,23 +254,13 @@ async def _get_recommendations( profile: TasteProfile | None, watched_tmdb: set[int], watched_imdb: set[str], - whitelist: set[int], - library_items: dict, + library_items: LibraryCollection, limit: int, user_settings: UserSettings | None = None, ) -> list[dict[str, Any]]: """Route to appropriate recommendation service based on catalog ID.""" - # Item-based recommendations - if any( - catalog_id.startswith(p) - for p in ( - "watchly.loved.", - "watchly.watched.", - ) - ): - # Extract item ID - item_id = re.sub(r"^watchly\.(loved|watched)\.", "", catalog_id) - + if any(catalog_id.startswith(p) for p in ("watchly.item.", "watchly.loved.", "watchly.watched.")): + item_id = re.sub(r"^watchly\.(item|loved|watched)\.", "", catalog_id) item_service: ItemBasedService = services["item"] recommendations = await item_service.get_recommendations_for_item( @@ -411,11 +269,9 @@ async def _get_recommendations( watched_tmdb=watched_tmdb, watched_imdb=watched_imdb, limit=limit, - whitelist=whitelist, ) logger.info(f"Found {len(recommendations)} recommendations for item {item_id}") - # Theme-based recommendations elif catalog_id.startswith("watchly.theme."): theme_service: ThemeBasedService = services["theme"] @@ -426,11 +282,9 @@ async def _get_recommendations( watched_tmdb=watched_tmdb, watched_imdb=watched_imdb, limit=limit, - whitelist=whitelist, ) logger.info(f"Found {len(recommendations)} recommendations for theme {catalog_id}") - # 
Creators-based recommendations elif catalog_id == "watchly.creators": creators_service: CreatorsService = services["creators"] @@ -447,7 +301,6 @@ async def _get_recommendations( recommendations = await self._get_trending_fallback(content_type, limit, user_settings) logger.info(f"Found {len(recommendations)} recommendations from creators") - # Top picks elif catalog_id == "watchly.rec": if profile: top_picks_service: TopPicksService = services["top_picks"] @@ -465,7 +318,6 @@ async def _get_recommendations( recommendations = await self._get_trending_fallback(content_type, limit, user_settings) logger.info(f"Found {len(recommendations)} top picks for {content_type}") - # Based on what you loved elif catalog_id in ("watchly.all.loved", "watchly.liked.all"): item_type = "loved" if catalog_id == "watchly.all.loved" else "liked" all_based_service: AllBasedService = services["all_based"] @@ -474,7 +326,6 @@ async def _get_recommendations( content_type=content_type, watched_tmdb=watched_tmdb, watched_imdb=watched_imdb, - whitelist=whitelist, limit=limit, item_type=item_type, profile=profile, diff --git a/app/services/recommendation/catalog_utils.py b/app/services/recommendation/catalog_utils.py new file mode 100644 index 0000000..ee5c44b --- /dev/null +++ b/app/services/recommendation/catalog_utils.py @@ -0,0 +1,58 @@ +import random +from typing import Any + +from app.core.settings import UserSettings + + +def should_shuffle(user_settings: UserSettings, catalog_id: str) -> bool: + config = next((c for c in user_settings.catalogs if c.id == catalog_id), None) + return getattr(config, "shuffle", False) if config else False + + +def shuffle_data_if_needed( + user_settings: UserSettings, catalog_id: str, data: list[dict[str, Any]] +) -> list[dict[str, Any]]: + if should_shuffle(user_settings, catalog_id): + random.shuffle(data) + return data + + +def clean_meta(meta: dict) -> dict | None: + """Return a sanitized Stremio meta object without internal fields. 
+ + Keeps only public keys and drops internal scoring/IDs/keywords/cast, etc. + """ + allowed = { + "id", + "type", + "name", + "poster", + "logo", + "background", + "description", + "releaseInfo", + "imdbRating", + "genres", + "runtime", + } + cleaned = {k: v for k, v in meta.items() if k in allowed} + # Drop empty values + cleaned = {k: v for k, v in cleaned.items() if v not in (None, "", [], {}, ())} + + # Normalize IMDb rating to a string with 1 decimal place + rating = cleaned.get("imdbRating") + if rating not in (None, ""): + try: + cleaned["imdbRating"] = f"{float(rating):.1f}" + except (TypeError, ValueError): + # Keep original value if it cannot be parsed + pass + + imdb_id = cleaned.get("id", "") + # if id does not start with tt, return None + if not imdb_id.startswith("tt"): + return None + # Add Metahub logo URL as fallback (used by Stremio) + if not cleaned.get("logo"): + cleaned["logo"] = f"https://live.metahub.space/logo/medium/{imdb_id}/img" + return cleaned diff --git a/app/services/recommendation/creators.py b/app/services/recommendation/creators.py index a8544f8..2facc8b 100644 --- a/app/services/recommendation/creators.py +++ b/app/services/recommendation/creators.py @@ -5,30 +5,72 @@ from loguru import logger from app.core.settings import UserSettings -from app.models.taste_profile import TasteProfile -from app.services.recommendation.filtering import RecommendationFiltering +from app.models.profile import TasteProfile +from app.services.recommendation.filtering import RecommendationFiltering, filter_watched_by_imdb from app.services.recommendation.metadata import RecommendationMetadata -from app.services.recommendation.utils import content_type_to_mtype, filter_watched_by_imdb +from app.services.recommendation.utils import content_type_to_mtype from app.services.tmdb.service import TMDBService +SMALL_LIBRARY_THRESHOLD = 5 +DIRECTOR_LIMIT = 3 +CAST_LIMIT = 3 +MIN_FREQUENCY = 2 + class CreatorsService: - """ - Handles recommendations from 
favorite creators (directors and cast). - - Strategy: - 1. Build profile from smart-sampled library items - 2. Get top directors and cast from profile - 3. Count raw frequencies to filter single-appearance creators - 4. Prioritize creators with 2+ appearances, fill with single if needed - 5. Fetch recommendations from each creator (fewer pages for single-appearance) - 6. Filter and return results + """Recommendations from creators the user actually returns to. + + A "favorite" creator is someone the user has watched across multiple + items, not just whoever made their last watch. With a sparse library + (1–3 items) every director and lead cast member trivially looks like + a "top creator", which made the old top-N-by-score selection feel + like "more from that one movie I watched". This service filters by + raw appearance frequency (`director_frequency` / `cast_frequency` + persisted on the profile) before fetching: + + * Cast: strict freq >= 2. A movie contributes the top 3 cast, so any + user with two watched items has a real chance of overlap; if no + actor recurs, the cast half of the catalog is empty. + * Directors: freq >= 2 preferred. As a small-library safety net, + when the profile has fewer than 5 processed items and nobody + recurs, fall back to the single highest-scored director so brand + new users still see a row. Once the library grows past the + threshold, "no recurring directors" is honest signal — the + catalog hides itself. + + If neither half qualifies, raise 404 (Stremio will hide the row). 
""" def __init__(self, tmdb_service: TMDBService, user_settings: UserSettings | None = None): self.tmdb_service: TMDBService = tmdb_service self.user_settings: UserSettings | None = user_settings + @staticmethod + def _select_recurring( + score_pairs: list[tuple[int, float]], + frequency: dict[int, int], + limit: int, + ) -> list[tuple[int, float]]: + """Keep score-sorted creators whose appearance count meets the threshold.""" + return [(cid, score) for cid, score in score_pairs if frequency.get(cid, 0) >= MIN_FREQUENCY][:limit] + + def _select_directors(self, profile: TasteProfile) -> list[tuple[int, float]]: + all_directors = sorted(profile.director_scores.items(), key=lambda kv: kv[1], reverse=True) + recurring = self._select_recurring(all_directors, profile.director_frequency, DIRECTOR_LIMIT) + if recurring: + return recurring + # Small-library fallback: brand-new users haven't had a chance to + # rewatch anyone yet, so seeding from their top-scored director is + # better than an empty catalog. Larger libraries with no recurrence + # legitimately have no "favorite" director — let the catalog hide. + if len(profile.processed_items) < SMALL_LIBRARY_THRESHOLD and all_directors: + return all_directors[:1] + return [] + + def _select_cast(self, profile: TasteProfile) -> list[tuple[int, float]]: + all_cast = sorted(profile.cast_scores.items(), key=lambda kv: kv[1], reverse=True) + return self._select_recurring(all_cast, profile.cast_frequency, CAST_LIMIT) + async def get_recommendations_from_creators( self, profile: TasteProfile, @@ -37,58 +79,43 @@ async def get_recommendations_from_creators( watched_imdb: set[str], limit: int = 20, ) -> list[dict[str, Any]]: - """ - Get recommendations from user's top favorite directors and cast. 
- - Args: - profile: User taste profile - content_type: Content type (movie/series) - watched_tmdb: Set of watched TMDB IDs - watched_imdb: Set of watched IMDB IDs - limit: Number of recommendations to return - - Returns: - List of recommended items - """ mtype = content_type_to_mtype(content_type) - # Get top 5 directors and cast directly from profile - selected_directors = profile.get_top_directors(limit=5) - selected_cast = profile.get_top_cast(limit=5) + selected_directors = self._select_directors(profile) + selected_cast = self._select_cast(profile) if not selected_directors and not selected_cast: - raise HTTPException(status_code=404, detail="No top directors or cast found") + raise HTTPException(status_code=404, detail="No recurring directors or cast in profile") + + logger.info( + f"Creators catalog: {len(selected_directors)} directors, {len(selected_cast)} cast " + f"(profile has {len(profile.processed_items)} processed items)" + ) - # Fetch recommendations from creators + min_rating, min_votes = RecommendationFiltering.get_quality_thresholds(self.user_settings) all_candidates = {} tasks = [] - # Create tasks for directors (fetch 2 pages each) for dir_id, _ in selected_directors: for page in [1, 2]: - # TV uses with_people, movies use with_crew - if mtype == "tv": - discover_params = {"with_people": str(dir_id), "page": page} - else: - discover_params = {"with_crew": str(dir_id), "page": page} - - # Apply dynamic filters - min_rating, min_votes = RecommendationFiltering.get_quality_thresholds(self.user_settings) - discover_params["vote_count.gte"] = min_votes - discover_params["vote_average.gte"] = min_rating - + # TMDB /discover supports with_crew for both movies and TV; + # with_people is a search-people endpoint param, not valid here. 
+ discover_params = { + "with_crew": str(dir_id), + "page": page, + "vote_count.gte": min_votes, + "vote_average.gte": min_rating, + } tasks.append(self._fetch_discover_page(mtype, discover_params, dir_id, "director")) - # Create tasks for cast (fetch 2 pages each) for cast_id, _ in selected_cast: for page in [1, 2]: - discover_params = {"with_cast": str(cast_id), "page": page} - - # Apply dynamic filters - min_rating, min_votes = RecommendationFiltering.get_quality_thresholds(self.user_settings) - discover_params["vote_count.gte"] = min_votes - discover_params["vote_average.gte"] = min_rating - + discover_params = { + "with_cast": str(cast_id), + "page": page, + "vote_count.gte": min_votes, + "vote_average.gte": min_rating, + } tasks.append(self._fetch_discover_page(mtype, discover_params, cast_id, "cast")) # Execute all tasks in parallel diff --git a/app/services/recommendation/diversity.py b/app/services/recommendation/diversity.py new file mode 100644 index 0000000..9e9e9ca --- /dev/null +++ b/app/services/recommendation/diversity.py @@ -0,0 +1,148 @@ +from collections import defaultdict +from datetime import datetime +from typing import Any + +from app.services.profile.constants import TOP_PICKS_CREATOR_CAP, TOP_PICKS_GENRE_CAP +from app.services.recommendation.filtering import RecommendationFiltering +from app.services.recommendation.scoring import RecommendationScoring + + +def extract_year(item: dict[str, Any]) -> int | None: + """Extract year from item.""" + release_date = item.get("release_date") or item.get("first_air_date") + if release_date: + try: + return int(str(release_date)[:4]) + except (ValueError, TypeError): + pass + return None + + +def is_recent_release(item: dict[str, Any], threshold: datetime, mtype: str) -> bool: + """Check if item was released within the threshold (e.g., last 3 months).""" + release_date_str = item.get("release_date") if mtype == "movie" else item.get("first_air_date") + if not release_date_str: + return False + + try: + 
release_date = datetime.strptime(str(release_date_str)[:10], "%Y-%m-%d") + return release_date >= threshold + except (ValueError, TypeError): + return False + + +def year_to_era(year: int) -> str: + """Convert year to era bucket.""" + if year < 1970: + return "pre-1970s" + elif year < 1980: + return "1970s" + elif year < 1990: + return "1980s" + elif year < 2000: + return "1990s" + elif year < 2010: + return "2000s" + elif year < 2020: + return "2010s" + else: + return "2020s" + + +def apply_diversity_caps( + scored_candidates: list[tuple[float, dict[str, Any]]], + limit: int, + mtype: str, + user_settings: Any = None, +) -> list[dict[str, Any]]: + """ + Apply diversity caps to ensure balanced results. + + Caps: + - Genre: max 50% per genre + - Quality: minimum vote_count and rating + """ + result = [] + genre_counts: dict[int, int] = defaultdict(int) + + max_per_genre = int(limit * TOP_PICKS_GENRE_CAP) + + for score, item in scored_candidates: + if len(result) >= limit: + break + + item_id = item.get("id") + if not item_id: + continue + + # Quality threshold + vote_count = item.get("vote_count", 0) + vote_avg = item.get("vote_average", 0) + + min_rating, min_votes = RecommendationFiltering.get_quality_thresholds(user_settings) + + if vote_count < min_votes: + continue + + wr = RecommendationScoring.weighted_rating(vote_avg, vote_count, C=7.2 if mtype == "tv" else 6.8) + if wr < min_rating: + continue + + # Check genre cap (50% max per genre) + genre_ids = item.get("genre_ids", []) + top_genre = genre_ids[0] if genre_ids else None + + if top_genre: + if genre_counts[top_genre] >= max_per_genre: + continue + + result.append(item) + + if top_genre: + genre_counts[top_genre] += 1 + + return result + + +def apply_creator_cap(items: list[dict[str, Any]], limit: int) -> list[dict[str, Any]]: + """ + Apply creator cap (max 2 items per director/actor) after enrichment. 
+ """ + result = [] + creator_counts: dict[int, int] = defaultdict(int) + + for item in items: + if len(result) >= limit: + break + + credits = item.get("credits", {}) or {} + crew = credits.get("crew", []) or [] + cast = credits.get("cast", []) or [] + + # Check director cap + directors = [c.get("id") for c in crew if c.get("job", "").lower() == "director" and c.get("id")] + blocked_by_director = False + for dir_id in directors: + if creator_counts[dir_id] >= TOP_PICKS_CREATOR_CAP: + blocked_by_director = True + break + + # Check cast cap (top 3 only) + top_cast = [c.get("id") for c in cast[:3] if c.get("id")] + blocked_by_cast = False + for cast_id in top_cast: + if creator_counts[cast_id] >= TOP_PICKS_CREATOR_CAP: + blocked_by_cast = True + break + + if blocked_by_director or blocked_by_cast: + continue + + result.append(item) + + for dir_id in directors: + creator_counts[dir_id] += 1 + for cast_id in top_cast: + creator_counts[cast_id] += 1 + + return result diff --git a/app/services/recommendation/filtering.py b/app/services/recommendation/filtering.py index f3f9794..6dd8e2a 100644 --- a/app/services/recommendation/filtering.py +++ b/app/services/recommendation/filtering.py @@ -1,6 +1,11 @@ +from datetime import datetime from typing import Any from urllib.parse import unquote +from app.core.constants import DISCOVERY_SETTINGS +from app.core.settings import DEFAULT_YEAR_MIN, get_current_year +from app.models.library import LibraryCollection + def parse_identifier(identifier: str) -> tuple[str | None, int | None]: """Parse Stremio identifier to extract IMDB ID and TMDB ID.""" @@ -36,31 +41,25 @@ class RecommendationFiltering: @staticmethod async def get_exclusion_sets( stremio_service: Any, - library_data: dict | None = None, + library_data: LibraryCollection | None = None, auth_key: str | None = None, ) -> tuple[set[str], set[int]]: - """ - Fetch library items and build exclusion sets for watched/loved content. 
- """ + """Build exclusion sets for watched/loved content.""" if library_data is None: if not auth_key: return set(), set() library_data = await stremio_service.library.get_library_items(auth_key) - library_data = library_data or {} + if library_data is None: + return set(), set() - all_items = ( - library_data.get("loved", []) - + library_data.get("watched", []) - + library_data.get("removed", []) - + library_data.get("liked", []) - ) + all_items = library_data.all_items_with_removed() imdb_ids = set() tmdb_ids = set() for item in all_items: - item_id = item.get("_id", "") + item_id = item.id if hasattr(item, "id") else item.get("_id", "") if not item_id: continue @@ -85,6 +84,19 @@ async def get_exclusion_sets( return imdb_ids, tmdb_ids + @staticmethod + def get_library_imdb_ids(library_data: dict | None) -> set[str]: + """Extract all IMDB IDs from Stremio library data.""" + if not library_data: + return set() + imdb_ids: set[str] = set() + for category in ("loved", "liked", "watched", "added", "removed"): + for item in library_data.get(category, []): + item_id = item.get("_id", "") + if item_id.startswith("tt"): + imdb_ids.add(item_id.split(":")[0]) + return imdb_ids + @staticmethod def filter_candidates( candidates: list[dict[str, Any]], watched_imdb: set[str], watched_tmdb: set[int] @@ -173,29 +185,129 @@ def get_excluded_genre_ids(user_settings: Any, content_type: str) -> list[int]: return [int(g) for g in user_settings.excluded_series_genres] return [] - @staticmethod - def get_genre_multiplier(genre_ids: list[int] | None, whitelist: set[int]) -> float: - """Calculate a score multiplier based on genre preference. 
Blocks animation if not preferred.""" - if not whitelist: - return 1.0 - gids = set(genre_ids or []) - if not gids: - return 1.0 +# --- Standalone filtering functions (moved from utils.py) --- + + +def filter_watched_by_imdb(enriched: list[dict[str, Any]], watched_imdb: set[str]) -> list[dict[str, Any]]: + """Filter enriched items by watched IMDB IDs.""" + final = [] + for item in enriched: + if item.get("id") in watched_imdb: + continue + if item.get("_external_ids", {}).get("imdb_id") in watched_imdb: + continue + final.append(item) + return final + + +def filter_by_genres( + items: list[dict[str, Any]], + watched_tmdb: set[int], + excluded_ids: list[int] | None = None, +) -> list[dict[str, Any]]: + """Filter items by watched set and excluded genres.""" + excluded_ids = excluded_ids or [] + filtered = [] + + for item in items: + item_id = item.get("id") + if not item_id or item_id in watched_tmdb: + continue + genre_ids = item.get("genre_ids", []) + if excluded_ids and any(gid in excluded_ids for gid in genre_ids): + continue + filtered.append(item) + + return filtered - # If it has at least one preferred genre, full score - if gids & whitelist: - return 1.0 - # Otherwise, soft penalty to prioritize whitelist items without blocking variety - return 0.4 +def build_discover_params(user_settings: Any) -> dict[str, Any]: + """Build TMDB discover API parameters based on user settings.""" + params: dict[str, Any] = {} + if not user_settings: + return params - @staticmethod - def passes_top_genre_whitelist(genre_ids: list[int] | None, whitelist: set[int]) -> bool: - """Check if an item's genres match the user's top genre whitelist (Softened).""" - if not whitelist: - return True - gids = set(genre_ids or []) - if not gids: - return True - return True + current_date = datetime.now() + current_year = get_current_year() + + year_min = getattr(user_settings, "year_min", DEFAULT_YEAR_MIN) + year_max = getattr(user_settings, "year_max", current_year) + + for prefix in 
["primary_release_date", "first_air_date"]: + params[f"{prefix}.gte"] = f"{year_min}-01-01" + if year_max >= current_year: + params[f"{prefix}.lte"] = current_date.strftime("%Y-%m-%d") + else: + params[f"{prefix}.lte"] = f"{year_max}-12-31" + + return params + + +def apply_discover_filters(params: dict[str, Any], user_settings: Any) -> dict[str, Any]: + """Merge discover params with global user settings (years, popularity).""" + if not user_settings: + return params + + global_params = build_discover_params(user_settings) + params = {**global_params, **params} + + min_rating, min_votes = RecommendationFiltering.get_quality_thresholds(user_settings) + + if "vote_count.gte" not in params: + params["vote_count.gte"] = min_votes + if "vote_average.gte" not in params: + params["vote_average.gte"] = min_rating + + return params + + +def filter_items_by_settings( + items: list[dict[str, Any]], user_settings: Any, simkl: bool = False +) -> list[dict[str, Any]]: + """Filter items post-fetch based on user settings (years, popularity).""" + if not user_settings: + return items + + year_min = getattr(user_settings, "year_min", DEFAULT_YEAR_MIN) + year_max = getattr(user_settings, "year_max", get_current_year()) + pop_pref = getattr(user_settings, "popularity", "balanced") + + # Hoist out of the per-item loop: this lookup doesn't depend on the item. + # If pop_pref has no mapping, fall back to no popularity filtering rather + # than dropping every item. 
+ params = DISCOVERY_SETTINGS.get(pop_pref) or {} + + ops = { + "gte": lambda x, y: x >= y, + "lte": lambda x, y: x <= y, + } + + filtered = [] + for item in items: + release_date = item.get("release_date") or item.get("first_air_date") or item.get("released") + if release_date: + try: + year = int(release_date.split("-")[0]) + if year < year_min or year > year_max: + continue + except (ValueError, IndexError): + pass + + passes_all = True + for param in params: + t_param, param_ops = param.split(".") + param_operator = ops.get(param_ops) + if not param_operator: + continue + if simkl and t_param == "popularity": + continue + item_value = item.get(t_param) + if item_value is None or not param_operator(item_value, params[param]): + passes_all = False + break + + if passes_all: + filtered.append(item) + + return filtered diff --git a/app/services/recommendation/item_based.py b/app/services/recommendation/item_based.py index 85bee27..0f71f67 100644 --- a/app/services/recommendation/item_based.py +++ b/app/services/recommendation/item_based.py @@ -3,15 +3,14 @@ from loguru import logger -from app.services.recommendation.filtering import RecommendationFiltering -from app.services.recommendation.metadata import RecommendationMetadata -from app.services.recommendation.utils import ( - content_type_to_mtype, +from app.services.recommendation.filtering import ( + RecommendationFiltering, filter_by_genres, filter_items_by_settings, filter_watched_by_imdb, - resolve_tmdb_id, ) +from app.services.recommendation.metadata import RecommendationMetadata +from app.services.recommendation.utils import content_type_to_mtype, resolve_tmdb_id from app.services.simkl import simkl_service from app.services.tmdb.service import TMDBService @@ -32,7 +31,6 @@ async def get_recommendations_for_item( watched_tmdb: set[int] | None = None, watched_imdb: set[str] | None = None, limit: int = 20, - whitelist: set[int] | None = None, ) -> list[dict[str, Any]]: """ Get recommendations for a specific 
item. @@ -41,8 +39,7 @@ async def get_recommendations_for_item( 1. Fetch similar + recommendations from TMDB (2 pages each) 2. Filter watched items 3. Filter excluded genres - 4. Apply genre whitelist - 5. Return top N + 4. Return top N Args: item_id: Item ID (tt... or tmdb:...) @@ -67,7 +64,17 @@ async def get_recommendations_for_item( # Fetch candidates (similar + recommendations, 2 pages each) tasks = [self._fetch_candidates_from_simkl(item_id, mtype), self._fetch_candidates(tmdb_id, mtype)] - simkl_candidates, candidates = await asyncio.gather(*tasks) + simkl_result, tmdb_result = await asyncio.gather(*tasks, return_exceptions=True) + if isinstance(simkl_result, Exception): + logger.warning(f"item-based simkl candidate fetch failed for {item_id}: {simkl_result}") + simkl_candidates: list = [] + else: + simkl_candidates = simkl_result + if isinstance(tmdb_result, Exception): + logger.warning(f"item-based tmdb candidate fetch failed for {item_id}: {tmdb_result}") + candidates: list = [] + else: + candidates = tmdb_result # Apply global settings filter (years, popularity) candidates = filter_items_by_settings(candidates, self.user_settings) @@ -77,7 +84,7 @@ async def get_recommendations_for_item( # Filter by genres and watched items excluded_ids = RecommendationFiltering.get_excluded_genre_ids(self.user_settings, content_type) - filtered = filter_by_genres(candidates, watched_tmdb, whitelist, excluded_ids) + filtered = filter_by_genres(candidates, watched_tmdb, excluded_ids) # Enrich metadata enriched = await RecommendationMetadata.fetch_batch( @@ -130,11 +137,10 @@ async def fetch_and_combine(fetch_method, source_name, pages: list[int] = [1, 2, if not combined or len(combined) < 30: await fetch_and_combine(self.tmdb_service.get_similar, "similar") - # apply filter and check - filtered = filter_items_by_settings(combined.values(), self.user_settings) - - if not filtered or len(filtered) < 30: - # fetch more similar items if there are less than 30 items after 
user_settings filter + # If the post-settings filter produces fewer than 30 candidates, pull + # more pages of similar before returning so the caller has headroom. + if len(filter_items_by_settings(combined.values(), self.user_settings)) < 30: await fetch_and_combine(self.tmdb_service.get_similar, "similar", pages=[4, 5, 6]) + # Caller re-applies filter_items_by_settings, so return the merged set. return list(combined.values()) diff --git a/app/services/recommendation/metadata.py b/app/services/recommendation/metadata.py index 764c788..0ffad1c 100644 --- a/app/services/recommendation/metadata.py +++ b/app/services/recommendation/metadata.py @@ -28,11 +28,7 @@ def extract_year(item: dict[str, Any]) -> int | None: @classmethod async def format_for_stremio( - cls, - details: dict[str, Any], - media_type: str, - user_settings: Any = None, - logo_url: str | None = None, + cls, details: dict[str, Any], media_type: str, user_settings: Any = None, logo_url: str | None = None ) -> dict[str, Any] | None: """Format TMDB details into Stremio metadata object.""" external_ids = details.get("external_ids", {}) @@ -156,7 +152,8 @@ async def _fetch_one(tid: int): return None tasks = [_fetch_one(it.get("id")) for it in valid_items] - details_list = await asyncio.gather(*tasks) + details_list = await asyncio.gather(*tasks, return_exceptions=True) + details_list = [d for d in details_list if d and not isinstance(d, Exception)] language = getattr(user_settings, "language", None) or "en-US" mt = "movie" if media_type == "movie" else "tv" diff --git a/app/services/recommendation/rotation.py b/app/services/recommendation/rotation.py deleted file mode 100644 index 4f36ee0..0000000 --- a/app/services/recommendation/rotation.py +++ /dev/null @@ -1,31 +0,0 @@ -"""Daily rotation utilities for fresh recommendations.""" - -import random - - -class DailyRotation: - """Utilities for rotating recommendations daily while maintaining quality.""" - - @staticmethod - def rotate_items(items: list, seed: 
str) -> list: - """ - Rotate the items daily. - - This provides freshness while maintaining quality: - - shuffled deterministically based on daily seed - - User sees different content every day without sacrificing quality - - Args: - items: List of items - seed: Daily rotation seed (changes daily) - - Returns: - Rotated list with items shuffled deterministically - """ - - # Deterministically shuffle items based on daily seed - rng = random.Random(seed) - shuffled_items = items.copy() - rng.shuffle(shuffled_items) - - return shuffled_items diff --git a/app/services/recommendation/scoring.py b/app/services/recommendation/scoring.py index 57ed57b..1abddcd 100644 --- a/app/services/recommendation/scoring.py +++ b/app/services/recommendation/scoring.py @@ -1,4 +1,3 @@ -import hashlib import math from collections.abc import Callable from typing import Any @@ -28,34 +27,6 @@ def normalize(value: float, min_v: float = 0.0, max_v: float = 10.0) -> float: return 0.0 return max(0.0, min(1.0, (value - min_v) / (max_v - min_v))) - @staticmethod - def stable_epsilon(tmdb_id: int, seed: str) -> float: - """Generate a stable tiny epsilon to break ties deterministically.""" - if not seed: - return 0.0 - h = hashlib.md5(f"{seed}:{tmdb_id}".encode()).hexdigest() - eps = int(h[-6:], 16) % 1000 - return eps / 1_000_000.0 - - @staticmethod - def generate_rotation_seed(token: str | None = None) -> str: - """ - Generate a daily rotation seed for deterministic but fresh recommendations. - - Args: - token: Optional user token for per-user variation. - If None, uses a global seed (same for all users on same day). 
- - Returns: - A seed string like "abc123:2026-01-15" - """ - from datetime import date - - today = date.today().isoformat() - if token: - return f"{token}:{today}" - return f"global:{today}" - @staticmethod def get_recency_multiplier_fn( profile: Any, candidate_decades: set[int] | None = None @@ -129,8 +100,7 @@ def calculate_final_score( profile: Any, scorer: Any, mtype: str, - rotation_seed: str | None = None, - ) -> float: # noqa: E501 + ) -> float: """ Calculate final recommendation score combining profile similarity and quality. @@ -139,12 +109,9 @@ def calculate_final_score( profile: User taste profile scorer: ProfileScorer instance mtype: Media type (movie/tv) to determine minimum rating - rotation_seed: Optional seed for daily rotation (e.g., "token:2026-01-15"). - When provided, adds a tiny epsilon for deterministic tie-breaking - that changes daily, making recommendations feel fresh. Returns: - Final combined score (0-1 range, with optional epsilon for rotation) + Final combined score (0-1 range) """ # Score with profile profile_score = scorer.score_item(item, profile) @@ -168,11 +135,4 @@ def calculate_final_score( # light boost for high-confidence items (no penalties!) 
vote_count = item.get("vote_count", 0) popularity = item.get("popularity", 0) - final_score = RecommendationScoring.apply_quality_adjustments(base_score, wr, vote_count, popularity) - # Apply daily rotation epsilon for tie-breaking (if seed provided) - if rotation_seed: - tmdb_id = item.get("id", 0) - epsilon = RecommendationScoring.stable_epsilon(tmdb_id, rotation_seed) - final_score += epsilon - - return final_score + return RecommendationScoring.apply_quality_adjustments(base_score, wr, vote_count, popularity) diff --git a/app/services/recommendation/theme_based.py b/app/services/recommendation/theme_based.py index 2838b94..f3f9646 100644 --- a/app/services/recommendation/theme_based.py +++ b/app/services/recommendation/theme_based.py @@ -3,7 +3,7 @@ from loguru import logger -from app.models.taste_profile import TasteProfile +from app.models.profile import TasteProfile from app.services.profile.constants import ( RUNTIME_BUCKET_MEDIUM_MAX_MOVIE, RUNTIME_BUCKET_MEDIUM_MAX_SERIES, @@ -11,15 +11,14 @@ RUNTIME_BUCKET_SHORT_MAX_SERIES, ) from app.services.profile.scorer import ProfileScorer -from app.services.recommendation.filtering import RecommendationFiltering -from app.services.recommendation.metadata import RecommendationMetadata -from app.services.recommendation.scoring import RecommendationScoring -from app.services.recommendation.utils import ( +from app.services.recommendation.filtering import ( + RecommendationFiltering, apply_discover_filters, - content_type_to_mtype, - filter_by_genres, filter_watched_by_imdb, ) +from app.services.recommendation.metadata import RecommendationMetadata +from app.services.recommendation.scoring import RecommendationScoring +from app.services.recommendation.utils import content_type_to_mtype from app.services.tmdb.service import TMDBService @@ -48,7 +47,6 @@ async def get_recommendations_for_theme( watched_tmdb: set[int] | None = None, watched_imdb: set[str] | None = None, limit: int = 20, - whitelist: set[int] | None = 
None, ) -> list[dict[str, Any]]: """Get recommendations for a role-based theme.""" watched_tmdb = watched_tmdb or set() @@ -148,7 +146,6 @@ async def get_recommendations_for_theme( # 5. Weighted Scoring scored = [] - rotation_seed = RecommendationScoring.generate_rotation_seed() mtype = content_type_to_mtype(content_type) for item in candidates: @@ -162,7 +159,6 @@ async def get_recommendations_for_theme( profile=profile, scorer=self.scorer, mtype=mtype, - rotation_seed=rotation_seed, ) else: base_score = RecommendationScoring.normalize(item.get("vote_average", 0)) @@ -254,9 +250,9 @@ def _axes_to_params(self, axes: dict, content_type: str) -> dict: start_year = int(val) end_year = start_year + 9 - prefix = "first_air_date" if content_type in ("tv", "series") else "primary_release_date" - params[f"{prefix}.gte"] = f"{start_year}-01-01" - params[f"{prefix}.lte"] = f"{end_year}-12-31" + prefix = "first_air_date" if content_type in ("tv", "series") else "primary_release_date" + params[f"{prefix}.gte"] = f"{start_year}-01-01" + params[f"{prefix}.lte"] = f"{end_year}-12-31" except Exception: logger.error("Failed to parse era axis: {}", axes["era"]) pass @@ -387,34 +383,3 @@ async def _fetch_discover_candidates( candidates.extend(res.get("results", [])) return candidates - - def _filter_candidates( - self, - candidates: list[dict[str, Any]], - watched_tmdb: set[int], - whitelist: set[int], - existing_ids: set[int] | None = None, - ) -> list[dict[str, Any]]: - """ - Filter candidates by watched items and genre whitelist. 
- - Args: - candidates: List of candidate items - watched_tmdb: Set of watched TMDB IDs - whitelist: Set of genre IDs in whitelist - existing_ids: Set of IDs to exclude (for deduplication) - - Returns: - Filtered list of items - """ - existing = existing_ids or set() - # First filter by genres (includes watched_tmdb check) - filtered = filter_by_genres(candidates, watched_tmdb, whitelist, None) - # Then deduplicate - result = [] - for item in filtered: - item_id = item.get("id") - if item_id and item_id not in existing: - result.append(item) - existing.add(item_id) - return result diff --git a/app/services/recommendation/top_picks.py b/app/services/recommendation/top_picks.py index e245598..c27d6e9 100644 --- a/app/services/recommendation/top_picks.py +++ b/app/services/recommendation/top_picks.py @@ -1,36 +1,33 @@ -import asyncio import time -from collections import defaultdict -from datetime import date, datetime from typing import Any from loguru import logger from app.core.constants import DEFAULT_CATALOG_LIMIT, MAX_CATALOG_ITEMS from app.core.settings import UserSettings -from app.models.taste_profile import TasteProfile -from app.services.profile.constants import TOP_PICKS_CREATOR_CAP, TOP_PICKS_GENRE_CAP -from app.services.profile.sampling import SmartSampler +from app.models.library import LibraryCollection +from app.models.profile import TasteProfile from app.services.profile.scorer import ProfileScorer -from app.services.recommendation.filtering import RecommendationFiltering +from app.services.profile.scoring import ScoringService +from app.services.recommendation.candidate_sources import CandidateFetcher +from app.services.recommendation.diversity import apply_diversity_caps +from app.services.recommendation.filtering import filter_watched_by_imdb from app.services.recommendation.metadata import RecommendationMetadata -from app.services.recommendation.rotation import DailyRotation from app.services.recommendation.scoring import RecommendationScoring 
-from app.services.recommendation.utils import ( - apply_discover_filters, - content_type_to_mtype, - filter_items_by_settings, - filter_watched_by_imdb, - resolve_tmdb_id, -) -from app.services.scoring import ScoringService -from app.services.simkl import simkl_service +from app.services.recommendation.utils import content_type_to_mtype from app.services.tmdb.service import TMDBService class TopPicksService: """ Generates top picks by combining multiple sources and applying diversity caps. + + Orchestrates: + 1. CandidateFetcher — gathers candidates from TMDB/Simkl/Discover + 2. RecommendationScoring — scores candidates against user profile + 3. apply_diversity_caps — ensures balanced genre/quality distribution + 4. RecommendationMetadata — enriches with full details + 5. apply_creator_cap — limits per-director/actor saturation """ def __init__(self, tmdb_service: TMDBService, user_settings: UserSettings | None = None): @@ -38,13 +35,13 @@ def __init__(self, tmdb_service: TMDBService, user_settings: UserSettings | None self.user_settings: UserSettings | None = user_settings self.scorer: ProfileScorer = ProfileScorer() self.scoring_service = ScoringService() - self.smart_sampler = SmartSampler(self.scoring_service) + self.candidate_fetcher = CandidateFetcher(tmdb_service, user_settings, self.scoring_service) async def get_top_picks( self, profile: TasteProfile, content_type: str, - library_items: dict[str, list[dict[str, Any]]], + library_items: LibraryCollection, watched_tmdb: set[int], watched_imdb: set[str], limit: int = DEFAULT_CATALOG_LIMIT, @@ -53,71 +50,27 @@ async def get_top_picks( Get top picks with diversity caps. Strategy: - 1. Fetch recommendations from top 8 library items - 1 page each - 2. Fetch discover with profile features (genres, keywords, cast, crew, era, country) - 3. Merge all candidates (deduped by TMDB ID) - 4. Score with ProfileScorer + Quality - 5. Apply diversity caps (relaxed: 50% genre, 50% era, 15% recent) - 6. 
Limit to 2x target before enrichment (performance optimization) - 7. Enrich metadata with full details - 8. Apply creator cap and final filters - 9. Return balanced results - - Args: - profile: User taste profile - content_type: Content type (movie/series) - library_items: Library items dict - watched_tmdb: Set of watched TMDB IDs - watched_imdb: Set of watched IMDB IDs - limit: Number of items to return - - Returns: - List of recommended items + 1. Fetch candidates from all sources (TMDB recs, Simkl, Discover) + 2. Filter out watched items + 3. Score with ProfileScorer + Quality + 4. Apply diversity caps + 5. Enrich metadata with full details + 6. Apply creator cap and final filters """ - start_time = time.time() - logger.info(f"Starting top picks generation for {content_type}, target limit={limit}") mtype = content_type_to_mtype(content_type) - all_candidates = {} - - # 1. Fetch recommendations from top items - # Use Simkl if API key available, otherwise fall back to TMDB - simkl_api_key = self.user_settings.simkl_api_key if self.user_settings else None - if simkl_api_key: - rec_candidates = await self._fetch_simkl_recommendations(library_items, content_type, mtype) - if not rec_candidates: - # Fallback to TMDB if Simkl returns nothing - logger.info("Simkl returned no results, falling back to TMDB") - rec_candidates = await self._fetch_recommendations_from_top_items(library_items, content_type, mtype) - # filter items - rec_candidates = filter_items_by_settings(rec_candidates, self.user_settings, simkl=True) - else: - rec_candidates = await self._fetch_recommendations_from_top_items(library_items, content_type, mtype) - # filter items - rec_candidates = filter_items_by_settings(rec_candidates, self.user_settings) - for item in rec_candidates: - if item.get("id"): - all_candidates[item["id"]] = item + # 1. Fetch and merge all candidates + all_candidates = await self.candidate_fetcher.fetch_all_candidates(profile, library_items, content_type, mtype) - # 2. 
Fetch discover with profile features - discover_candidates = await self._fetch_discover_with_profile(profile, content_type, mtype) - # filter by user settings - discover_candidates = filter_items_by_settings(discover_candidates, self.user_settings) - for item in discover_candidates: - if item.get("id"): - all_candidates[item["id"]] = item - - # Filter out watched items + # 2. Filter out watched items filtered_candidates = [item for item in all_candidates.values() if item.get("id") not in watched_tmdb] - logger.info(f"Found {len(filtered_candidates)} candidates after filtering out watched items and user settings") - # Score all candidates with profile + # 3. Score all candidates with profile scored_candidates = [] - rotation_seed = RecommendationScoring.generate_rotation_seed() # Daily rotation for fresh recommendations for item in filtered_candidates: try: final_score = RecommendationScoring.calculate_final_score( @@ -125,493 +78,34 @@ async def get_top_picks( profile=profile, scorer=self.scorer, mtype=mtype, - rotation_seed=rotation_seed, ) scored_candidates.append((final_score, item)) except Exception as e: logger.debug(f"Failed to score item {item.get('id')}: {e}") continue - # Sort by score scored_candidates.sort(key=lambda x: x[0], reverse=True) - logger.info(f"Scored {len(scored_candidates)} candidates.") - # Apply diversity caps - result = self._apply_diversity_caps(scored_candidates, len(scored_candidates), mtype) + # 4. Apply diversity caps (cap to 3x the target so the genre cap is meaningful + # and we still have headroom for the post-enrichment filters) + diversity_target = limit * 3 + result = apply_diversity_caps(scored_candidates, diversity_target, mtype, self.user_settings) logger.info(f"After diversity caps: {len(result)} items") - # Limit before enrichment to avoid timeout (only enrich 3x what we need) - result = result[: limit * 3] - logger.info(f"After diversity caps and pre-enrichment limit: {len(result)} items") - - # Enrich metadata + # 5. 
Enrich metadata enriched = await RecommendationMetadata.fetch_batch( self.tmdb_service, result, content_type, user_settings=self.user_settings ) logger.info(f"Enriched {len(enriched)} items with full metadata") - # Final filter + # 6. Final filter filtered = filter_watched_by_imdb(enriched, watched_imdb) - rotated = DailyRotation.rotate_items(filtered, rotation_seed) - elapsed_time = time.time() - start_time logger.info( - f"Top picks complete: {len(rotated)} items returned in {elapsed_time:.2f}s " + f"Top picks complete: {len(filtered)} items returned in {elapsed_time:.2f}s " f"(target: {limit}, candidates: {len(all_candidates)}, scored: {len(scored_candidates)})" ) - return rotated[:MAX_CATALOG_ITEMS] - - async def _fetch_recommendations_from_top_items( - self, - library_items: dict[str, list[dict[str, Any]]], - content_type: str, - mtype: str, - ) -> list[dict[str, Any]]: - """ - Fetch recommendations from top items (loved/watched/liked/added). - - Args: - library_items: Library items dict - content_type: Content type - mtype: TMDB media type (movie/tv) - - Returns: - List of candidate items - """ - # Get top items (loved first, then liked, then added, then top watched) - top_items = self.smart_sampler.sample_items(library_items, content_type, max_items=15) - - candidates = [] - tasks = [] - - for item in top_items: - item = item.item - item_id = item.id - if not item_id: - continue - - # Resolve TMDB ID - tmdb_id = await resolve_tmdb_id(item_id, self.tmdb_service) - if not tmdb_id: - continue - - # Fetch recommendations (1 page only) - tasks.append(self.tmdb_service.get_recommendations(tmdb_id, mtype, page=1)) - # tasks.append(self.tmdb_service.get_similar(tmdb_id, mtype, page=1)) - - # Execute all in parallel - logger.info(f"Fetching recommendations from {len(tasks)} top library items") - results = await asyncio.gather(*tasks, return_exceptions=True) - - failed_count = 0 - for res in results: - if isinstance(res, Exception): - failed_count += 1 - 
logger.debug(f"Recommendation fetch failed: {res}") - continue - candidates.extend(res.get("results", [])) - - if failed_count > 0: - logger.info(f"{failed_count}/{len(tasks)} recommendation fetches failed (expected for items with no recs)") - logger.debug(f"Fetched {len(candidates)} candidates from top items") - - return candidates - - async def _fetch_simkl_recommendations( - self, - library_items: dict[str, list[dict[str, Any]]], - content_type: str, - mtype: str, - ) -> list[dict[str, Any]]: - """ - Fetch recommendations from Simkl for top library items. - - Args: - library_items: Library items dict - content_type: Content type - mtype: TMDB media type (movie/tv) - - Returns: - List of candidate items in TMDB-compatible format - """ - simkl_api_key = self.user_settings.simkl_api_key if self.user_settings else None - if not simkl_api_key: - logger.warning("Simkl API key not found, skipping Simkl recommendations") - return [] - - # Sample top items (same as TMDB flow - 15 items) - top_items = self.smart_sampler.sample_items(library_items, content_type, max_items=15) - - # Extract IMDB IDs - imdb_ids = [] - for scored_item in top_items: - item_id = scored_item.item.id - if item_id and item_id.startswith("tt"): - imdb_ids.append(item_id) - - if not imdb_ids: - logger.warning("No valid IMDB IDs found for Simkl recommendations") - return [] - - logger.info(f"Fetching Simkl recommendations for {len(imdb_ids)} items") - - # Get year range for filtering - year_min = getattr(self.user_settings, "year_min", None) - year_max = getattr(self.user_settings, "year_max", None) - - # Fetch from Simkl (batch optimized with early year filtering) - try: - candidates = await simkl_service.get_recommendations_batch( - imdb_ids, - mtype, - simkl_api_key, - max_per_item=8, - year_min=year_min, - year_max=year_max, - ) - except Exception as e: - logger.error(f"Error fetching Simkl recommendations: {e}") - return [] - - logger.info(f"Fetched {len(candidates)} candidates from Simkl") - 
return candidates - - def _add_discover_task(self, tasks: list, mtype: str, without_genres: str | None, **kwargs: Any) -> None: - """ - Add a discover task to the list of tasks with default parameters. - """ - sort_by = RecommendationFiltering.get_sort_by_preference(self.user_settings) - params = { - "sort_by": sort_by, - **kwargs, - } - if without_genres: - params["without_genres"] = without_genres - - # Apply global user filters (year range, popularity) - params = apply_discover_filters(params, self.user_settings) - - tasks.append(self.tmdb_service.get_discover(mtype, **params)) - - async def _fetch_discover_with_profile( - self, profile: TasteProfile, content_type: str, mtype: str - ) -> list[dict[str, Any]]: - """ - Fetch discover results using profile features. - - Args: - profile: User taste profile - content_type: Content type - mtype: TMDB media type - - Returns: - List of candidate items - """ - - excluded_genre_ids = RecommendationFiltering.get_excluded_genre_ids(self.user_settings, content_type) - without_genres = "|".join(str(g) for g in excluded_genre_ids) if excluded_genre_ids else None - - logger.debug(f"Excluded genres for {content_type}: {excluded_genre_ids}") - - # Get top features from profile - top_genres = profile.get_top_genres(limit=5) - top_keywords = profile.get_top_keywords(limit=5) - top_directors = profile.get_top_directors(limit=3) - top_cast = profile.get_top_cast(limit=5) - top_eras = profile.get_top_eras(limit=2) - top_countries = profile.get_top_countries(limit=5) - - candidates = [] - tasks = [] - - # Discover with genres - if top_genres: - genre_ids = [g[0] for g in top_genres] - self._add_discover_task( - tasks, - mtype, - without_genres, - with_genres="|".join(str(g) for g in genre_ids), - page=1, - ) - - # Discover with keywords - if top_keywords: - keyword_ids = [k[0] for k in top_keywords] - for page in range(1, 3): # 2 pages - self._add_discover_task( - tasks, - mtype, - without_genres, - with_keywords="|".join(str(k) for k 
in keyword_ids), - page=page, - ) - - # Discover with directors - if top_directors: - director_ids = [d[0] for d in top_directors] - self._add_discover_task( - tasks, - mtype, - without_genres, - with_crew="|".join(str(d) for d in director_ids), - page=1, - ) - - # Discover with cast - if top_cast: - cast_ids = [c[0] for c in top_cast] - self._add_discover_task( - tasks, - mtype, - without_genres, - with_cast="|".join(str(c) for c in cast_ids), - page=1, - ) - - # Discover with era (year range) - if top_eras: - era = top_eras[0][0] - year_start = self._era_to_year_start(era) - if year_start: - prefix = "first_air_date" if mtype == "tv" else "primary_release_date" - lte_prefix = ( - date.today().isoformat() if year_start + 9 > date.today().year else f"{year_start + 9}-12-31" - ) - params = { - f"{prefix}.gte": f"{year_start}-01-01", - f"{prefix}.lte": lte_prefix, - "page": 1, - } - - self._add_discover_task(tasks, mtype, without_genres, **params) - - # Discover with countries - if top_countries: - country_codes = [c[0] for c in top_countries] - params = { - "with_origin_country": "|".join(country_codes), - "page": 1, - } - self._add_discover_task(tasks, mtype, without_genres, **params) - - # Execute all in parallel - logger.debug(f"Fetching {len(tasks)} discover queries with profile features") - results = await asyncio.gather(*tasks, return_exceptions=True) - - failed_count = 0 - for res in results: - if isinstance(res, Exception): - failed_count += 1 - logger.warning(f"Discover query failed: {res}") - continue - candidates.extend(res.get("results", [])) - - if failed_count > 0: - logger.warning(f"{failed_count}/{len(tasks)} discover queries failed") - logger.debug(f"Fetched {len(candidates)} candidates from discover") - - return candidates - - async def _fetch_trending_and_popular(self, content_type: str, mtype: str) -> list[dict[str, Any]]: - """ - Fetch trending and popular items (for recent items injection). 
- - Args: - content_type: Content type - mtype: TMDB media type - - Returns: - List of candidate items - """ - candidates = [] - - # Fetch trending (1 page) - try: - trending = await self.tmdb_service.get_trending(mtype, time_window="week", page=1) - candidates.extend(trending.get("results", [])) - except Exception as e: - logger.debug(f"Failed to fetch trending: {e}") - - return candidates - - def _apply_diversity_caps( - self, - scored_candidates: list[tuple[float, dict[str, Any]]], - limit: int, - mtype: str, - ) -> list[dict[str, Any]]: - """ - Apply diversity caps to ensure balanced results. - - Caps: - - Genre: max 50% per genre - - Era: max 50% per era - - Quality: minimum vote_count and rating - - Args: - scored_candidates: List of (score, item) tuples, sorted by score - limit: Target number of items - mtype: Media type for quality checks - - Returns: - Filtered and capped list of items - """ - result = [] - genre_counts = defaultdict(int) - # era_counts = defaultdict(int) - - max_per_genre = int(limit * TOP_PICKS_GENRE_CAP) - # max_per_era = int(limit * TOP_PICKS_ERA_CAP) - - for score, item in scored_candidates: - if len(result) >= limit: - break - - item_id = item.get("id") - if not item_id: - continue - - # Quality threshold - vote_count = item.get("vote_count", 0) - vote_avg = item.get("vote_average", 0) - - # Dynamic check - min_rating, min_votes = RecommendationFiltering.get_quality_thresholds(self.user_settings) - - if vote_count < min_votes: - continue - - # We keep weighted rating check but use dynamic base - wr = RecommendationScoring.weighted_rating(vote_avg, vote_count, C=7.2 if mtype == "tv" else 6.8) - if wr < min_rating: - continue - - # Check genre cap (50% max per genre) - genre_ids = item.get("genre_ids", []) - top_genre = genre_ids[0] if genre_ids else None - - if top_genre: - if genre_counts[top_genre] >= max_per_genre: - continue - - # Add item - result.append(item) - - if top_genre: - genre_counts[top_genre] += 1 - - return result - - 
def _apply_creator_cap(self, items: list[dict[str, Any]], limit: int) -> list[dict[str, Any]]: - """ - Apply creator cap (max 2 items per director/actor) after enrichment. - - Args: - items: List of enriched items with full metadata - limit: Target limit - - Returns: - Filtered list respecting creator cap - """ - result = [] - creator_counts = defaultdict(int) - - for item in items: - if len(result) >= limit: - break - - # Extract creators from credits - credits = item.get("credits", {}) or {} - crew = credits.get("crew", []) or [] - cast = credits.get("cast", []) or [] - - # Check director cap - directors = [c.get("id") for c in crew if c.get("job", "").lower() == "director" and c.get("id")] - blocked_by_director = False - for dir_id in directors: - if creator_counts[dir_id] >= TOP_PICKS_CREATOR_CAP: - blocked_by_director = True - break - - # Check cast cap (top 3 only) - top_cast = [c.get("id") for c in cast[:3] if c.get("id")] - blocked_by_cast = False - for cast_id in top_cast: - if creator_counts[cast_id] >= TOP_PICKS_CREATOR_CAP: - blocked_by_cast = True - break - - if blocked_by_director or blocked_by_cast: - continue - - # Add item - result.append(item) - - # Update creator counts - for dir_id in directors: - creator_counts[dir_id] += 1 - for cast_id in top_cast: - creator_counts[cast_id] += 1 - - return result - - @staticmethod - def _extract_year(item: dict[str, Any]) -> int | None: - """Extract year from item.""" - release_date = item.get("release_date") or item.get("first_air_date") - if release_date: - try: - return int(str(release_date)[:4]) - except (ValueError, TypeError): - pass - return None - - @staticmethod - def _is_recent_release(item: dict[str, Any], threshold: datetime, mtype: str) -> bool: - """Check if item was released within the threshold (e.g., last 3 months).""" - release_date_str = item.get("release_date") if mtype == "movie" else item.get("first_air_date") - if not release_date_str: - return False - - try: - # Parse release date 
(format: YYYY-MM-DD) - release_date = datetime.strptime(str(release_date_str)[:10], "%Y-%m-%d") - return release_date >= threshold - except (ValueError, TypeError): - return False - - @staticmethod - def _year_to_era(year: int) -> str: - """Convert year to era bucket.""" - if year < 1970: - return "pre-1970s" - elif year < 1980: - return "1970s" - elif year < 1990: - return "1990s" - elif year < 2000: - return "2000s" - elif year < 2010: - return "2010s" - elif year < 2020: - return "2020s" - else: - return "2020s" - - @staticmethod - def _era_to_year_start(era: str) -> int | None: - """Convert era bucket to starting year.""" - era_map = { - "pre-1970s": 1950, - "1970s": 1970, - "1980s": 1980, - "1990s": 1990, - "2000s": 2000, - "2010s": 2010, - "2020s": 2020, - } - return era_map.get(era) + return filtered[:MAX_CATALOG_ITEMS] diff --git a/app/services/recommendation/utils.py b/app/services/recommendation/utils.py index ab6ea90..89b12ef 100644 --- a/app/services/recommendation/utils.py +++ b/app/services/recommendation/utils.py @@ -1,28 +1,14 @@ from typing import Any -from loguru import logger - -from app.core.constants import DISCOVERY_SETTINGS -from app.services.recommendation.filtering import RecommendationFiltering -from app.services.recommendation.metadata import RecommendationMetadata - def content_type_to_mtype(content_type: str) -> str: return "tv" if content_type in ("tv", "series") else "movie" async def resolve_tmdb_id(item_id: str, tmdb_service: Any) -> int | None: - """ - Resolve item ID to TMDB ID. - - Handles various formats: tmdb:123, tt123456, or plain integer. - - Args: - item_id: Item ID in various formats - tmdb_service: TMDB service instance for IMDB lookups + """Resolve item ID to TMDB ID. - Returns: - TMDB ID or None + Handles formats: tmdb:123, tt123456, or plain integer. 
""" if item_id.startswith("tmdb:"): try: @@ -37,290 +23,3 @@ async def resolve_tmdb_id(item_id: str, tmdb_service: Any) -> int | None: return int(item_id) except ValueError: return None - - -def filter_watched_by_imdb(enriched: list[dict[str, Any]], watched_imdb: set[str]) -> list[dict[str, Any]]: - """ - Filter enriched items by watched IMDB IDs. - - Checks both the item's 'id' field and '_external_ids.imdb_id' field. - - Args: - enriched: List of enriched metadata items - watched_imdb: Set of watched IMDB IDs - - Returns: - Filtered list excluding watched items - """ - final = [] - for item in enriched: - if item.get("id") in watched_imdb: - continue - if item.get("_external_ids", {}).get("imdb_id") in watched_imdb: - continue - final.append(item) - return final - - -def filter_by_genres( - items: list[dict[str, Any]], - watched_tmdb: set[int], - whitelist: set[int] | None = None, - excluded_ids: list[int] | None = None, -) -> list[dict[str, Any]]: - """ - Filter items by genre whitelist and excluded genres. - - Args: - items: List of candidate items - watched_tmdb: Set of watched TMDB IDs to exclude - whitelist: Optional genre whitelist - excluded_ids: Optional list of excluded genre IDs - - Returns: - Filtered list of items - """ - whitelist = whitelist or set() - excluded_ids = excluded_ids or [] - filtered = [] - - for item in items: - item_id = item.get("id") - if not item_id or item_id in watched_tmdb: - continue - - genre_ids = item.get("genre_ids", []) - - # Excluded genres check - if excluded_ids and any(gid in excluded_ids for gid in genre_ids): - continue - - filtered.append(item) - - return filtered - - -async def pad_to_min( - content_type: str, - existing: list[dict], - min_items: int, - tmdb_service: Any, - user_settings: Any = None, - watched_tmdb: set[int] | None = None, - watched_imdb: set[str] | None = None, -) -> list[dict]: - """ - Pad recommendations to meet minimum item count by fetching trending/popular items. 
- - Args: - content_type: Content type (movie/series) - existing: Existing recommendations - min_items: Minimum number of items required - tmdb_service: TMDB service instance - user_settings: User settings (optional) - watched_tmdb: Set of watched TMDB IDs (optional) - watched_imdb: Set of watched IMDB IDs (optional) - - Returns: - List of recommendations padded to min_items - """ - need = max(0, int(min_items) - len(existing)) - if need <= 0: - return existing - - # Use provided watched sets (or empty sets if not provided) - watched_tmdb = watched_tmdb or set() - watched_imdb = watched_imdb or set() - excluded_ids = set(RecommendationFiltering.get_excluded_genre_ids(user_settings, content_type)) - - mtype = content_type_to_mtype(content_type) - pool = [] - - try: - tr = await tmdb_service.get_trending(mtype, time_window="week") - pool.extend(tr.get("results", [])[:60]) - tr2 = await tmdb_service.get_top_rated(mtype) - pool.extend(tr2.get("results", [])[:60]) - except Exception as e: - logger.debug(f"Error fetching trending/top-rated for padding: {e}") - return existing - - # Filter pool by user settings (years, popularity) - pool = filter_items_by_settings(pool, user_settings) - - # Get existing TMDB IDs - existing_tmdb = set() - for it in existing: - tid = it.get("_tmdb_id") or it.get("tmdb_id") or it.get("id") - try: - if isinstance(tid, str) and tid.startswith("tmdb:"): - tid = int(tid.split(":")[1]) - existing_tmdb.add(int(tid)) - except Exception: - pass - - # Filter pool - dedup = {} - for it in pool: - tid = it.get("id") - if not tid or tid in existing_tmdb or tid in watched_tmdb: - continue - gids = it.get("genre_ids") or [] - if excluded_ids.intersection(gids): - continue - - # Quality threshold - va, vc = float(it.get("vote_average") or 0.0), int(it.get("vote_count") or 0) - if vc < 200 or va < 6.0: - continue - dedup[tid] = it - if len(dedup) >= need * 3: - break - - if not dedup: - return existing - - # Enrich metadata - meta = await 
RecommendationMetadata.fetch_batch( - tmdb_service, - list(dedup.values()), - content_type, - user_settings=user_settings, - ) - - # Add to existing, filtering watched items - extra = [] - for it in meta: - if it.get("id") in watched_imdb: - continue - if it.get("_external_ids", {}).get("imdb_id") in watched_imdb: - continue - - # Final check against existing - is_dup = False - for e in existing: - if e.get("id") == it.get("id"): - is_dup = True - break - if is_dup: - continue - - it.pop("_external_ids", None) - extra.append(it) - if len(extra) >= need: - break - - return existing + extra - - -def build_discover_params(user_settings: Any) -> dict[str, Any]: - """ - Build TMDB discover API parameters based on user settings. - """ - params = {} - if not user_settings: - return params - - from datetime import datetime - - current_date = datetime.now() - current_year = current_date.year - - # 1. Year Range - year_min = getattr(user_settings, "year_min", 1970) - year_max = getattr(user_settings, "year_max", current_year) - - # Apply to both movie and tv date fields for convenience in merging - for prefix in ["primary_release_date", "first_air_date"]: - params[f"{prefix}.gte"] = f"{year_min}-01-01" - - # If year_max is current year or greater, use today's date for 'lte' - # relative to the current time. - if year_max >= current_year: - params[f"{prefix}.lte"] = current_date.strftime("%Y-%m-%d") - else: - params[f"{prefix}.lte"] = f"{year_max}-12-31" - - return params - - -def apply_discover_filters(params: dict[str, Any], user_settings: Any) -> dict[str, Any]: - """ - Merge specific discover params with global user settings (years, popularity). 
- """ - if not user_settings: - return params - - global_params = build_discover_params(user_settings) - - params = {**global_params, **params} - - min_rating, min_votes = RecommendationFiltering.get_quality_thresholds(user_settings) - - # Apply dynamic thresholds if not overridden by stricter local params - if "vote_count.gte" not in params: - params["vote_count.gte"] = min_votes - - if "vote_average.gte" not in params: - params["vote_average.gte"] = min_rating - - return params - - -def filter_items_by_settings( - items: list[dict[str, Any]], user_settings: Any, simkl: bool = False -) -> list[dict[str, Any]]: - """ - Filter items post-fetch based on global user settings (years, popularity). - Used for items from recommendations/similar APIs that don't support early filtering. - """ - if not user_settings: - return items - - year_min = getattr(user_settings, "year_min", 1970) - year_max = getattr(user_settings, "year_max", 2026) - pop_pref = getattr(user_settings, "popularity", "balanced") - - filtered = [] - - for item in items: - # 1. 
Year Filtering - release_date = item.get("release_date") or item.get("first_air_date") or item.get("released") - if release_date: - try: - year = int(release_date.split("-")[0]) - if year < year_min or year > year_max: - continue - except (ValueError, IndexError): - pass - - params = DISCOVERY_SETTINGS.get(pop_pref, {}) - if not params: - continue - - # determine operations - ops = { - "gte": lambda x, y: x >= y, - "lte": lambda x, y: x <= y, - } - - passes_all_checks = True - for param in params: - t_param, param_ops = param.split(".") - param_operator = ops.get(param_ops) - if not param_operator: - continue - - # skip popularity params if simkl - if simkl and t_param == "popularity": - continue - - item_value = item.get(t_param) - if item_value is None or not param_operator(item_value, params[param]): - passes_all_checks = False - break - - if passes_all_checks: - filtered.append(item) - - return filtered diff --git a/app/services/redis_service.py b/app/services/redis_service.py index 51159a7..9ae31f2 100644 --- a/app/services/redis_service.py +++ b/app/services/redis_service.py @@ -84,6 +84,20 @@ async def delete(self, key: str) -> bool: logger.error(f"Failed to delete key '{key}' from Redis: {exc}") return False + async def expire(self, key: str, ttl: int) -> bool: + """Refresh the TTL on an existing key. Used to keep active users' + caches alive without rewriting the value on every read. + + Returns True if the key existed and the TTL was set, False otherwise. + """ + try: + client = await self.get_client() + result = await client.expire(key, ttl) + return bool(result) + except (redis.RedisError, OSError) as exc: + logger.error(f"Failed to set TTL on key '{key}' in Redis: {exc}") + return False + async def exists(self, key: str) -> bool: """Check if a key exists in Redis. 
diff --git a/app/services/row_generator.py b/app/services/row_generator.py index b503774..ad9573c 100644 --- a/app/services/row_generator.py +++ b/app/services/row_generator.py @@ -1,551 +1,395 @@ """ Dynamic Row Generator Service. -Generates 3 personalized catalog rows using a tiered sampling system: -- Row 1 (The Core): User's strongest preferences (Gold tier: Top 1-3) -- Row 2 (The Blend): Mixed preferences with higher complexity (Gold+Silver: Top 1-8) -- Row 3 (The Rising Star): Emerging interests (Silver tier: Rank 4-10) +Generates 3 personalized catalog rows from a user's taste profile: +- Row 1 (Core): Strongest preferences +- Row 2 (Blend): Mixed preferences with variety +- Row 3 (Rising Star): Emerging/exploratory interests """ import asyncio import random -from enum import Enum from typing import Any from loguru import logger from pydantic import BaseModel, Field -from app.models.taste_profile import TasteProfile +from app.models.profile import TasteProfile from app.services.gemini import gemini_service from app.services.tmdb.countries import COUNTRY_ADJECTIVES from app.services.tmdb.genre import movie_genres, series_genres from app.services.tmdb.service import TMDBService, get_tmdb_service -GOLD_TIER_LIMIT = 3 # Top 1-3 items -SILVER_TIER_START = 3 # Rank 4+ -SILVER_TIER_END = 10 # Up to Rank 10 +GOLD_END = 3 +SILVER_START = 3 +SILVER_END = 10 -# Available axes for row generation -AXIS_GENRE = "genre" -AXIS_KEYWORD = "keyword" -AXIS_COUNTRY = "country" -AXIS_RUNTIME = "runtime" -AXIS_CREATOR = "creator" +ROLE_ANCHOR = "a" +ROLE_FLAVOR = "f" +ROLE_FALLBACK = "b" +AXIS_GENRE = "g" +AXIS_KEYWORD = "k" +AXIS_COUNTRY = "ct" +AXIS_RUNTIME = "r" +AXIS_CREATOR = "cr" -class AxisRole(str, Enum): - ANCHOR = "anchor" # strong signal, near-required - FLAVOR = "flavor" # boosts relevance, optional - FALLBACK = "fallback" # ranking only, never filtering - - -class RowAxis(BaseModel): - name: str - value: Any - role: AxisRole - weight: float = 1.0 - - -def 
normalize_keyword(kw: str) -> str: - """Normalize keyword for display.""" - return kw.strip().replace("-", " ").replace("_", " ").title() - - -def get_genre_name(genre_id: int, content_type: str) -> str: - """Get genre name from ID.""" - genre_map = movie_genres if content_type == "movie" else series_genres - return genre_map.get(genre_id, "Movies" if content_type == "movie" else "Series") - - -def get_country_adjective(country_code: str) -> str | None: - """Get country adjective (e.g., 'US' -> 'American').""" - adjectives = COUNTRY_ADJECTIVES.get(country_code, []) - return random.choice(adjectives) if adjectives else None - - -def runtime_to_modifier(bucket: str) -> str | None: - """Get display modifier for runtime bucket.""" - modifiers = { - "short": "Short & Sweet", - "medium": None, # No modifier for medium - "long": "Epic", - } - return modifiers.get(bucket) +class RowDefinition(BaseModel): + """A dynamic catalog row with an ID (encodes TMDB params) and a display title.""" -def sample_from_tier(items: list[tuple[Any, float]], start: int, end: int, count: int = 1) -> list[tuple[Any, float]]: - """Sample random items from a specific tier range.""" - tier_items = items[start:end] - if not tier_items: - return [] - return random.sample(tier_items, min(count, len(tier_items))) + title: str + id: str -def sample_from_gold(items: list[tuple[Any, float]], count: int = 1) -> list[tuple[Any, float]]: - """Sample from Gold tier (Top 1-3).""" - return sample_from_tier(items, 0, GOLD_TIER_LIMIT, count) +class LLMRowTheme(BaseModel): + """Schema for Gemini structured output — a single themed catalog row.""" + title: str = Field(description="Creative, short title for the collection (2-5 words)") + genres: list[int] = Field(description="List of valid TMDB genre IDs") + keywords: list[str] = Field(default_factory=list, description="Specific TMDB keyword names") + country: str | None = Field(default=None, description="ISO 3166-1 country code or null") -def 
sample_from_silver(items: list[tuple[Any, float]], count: int = 1) -> list[tuple[Any, float]]: - """Sample from Silver tier (Rank 4-10).""" - return sample_from_tier(items, SILVER_TIER_START, SILVER_TIER_END, count) +# --- ID building (format must match theme_based.py parser) --- -def sample_from_gold_silver(items: list[tuple[Any, float]], count: int = 1) -> list[tuple[Any, float]]: - """Sample from combined Gold+Silver tier (Rank 1-10).""" - return sample_from_tier(items, 0, SILVER_TIER_END, count) +def build_row_id(axes: list[tuple[str, str, Any]]) -> str: + """Build row ID from axes. Each axis is (role, axis_type, value). -def build_row_id(axes: list[RowAxis]) -> str: - """Build a unique row ID from axes and their roles.""" + Example output: watchly.theme.a:g28.b:rshort.f:k1234 + """ parts = ["watchly.theme"] + sorted_axes = sorted(axes, key=lambda x: (x[0], x[1], str(x[2]))) + for role, axis_type, value in sorted_axes: + if isinstance(value, (list, tuple)): + value = "-".join(str(v) for v in value) + parts.append(f"{role}:{axis_type}{value}") + return ".".join(parts) - role_map = { - AxisRole.ANCHOR: "a", - AxisRole.FLAVOR: "f", - AxisRole.FALLBACK: "b", - } - # Sort axes for consistent IDs - sorted_axes = sorted(axes, key=lambda x: (x.role, x.name, str(x.value))) +# --- Display helpers --- - for axis in sorted_axes: - role_pfx = role_map.get(axis.role, "f") - axis_pfx = { - AXIS_GENRE: "g", - AXIS_KEYWORD: "k", - AXIS_COUNTRY: "ct", - AXIS_RUNTIME: "r", - AXIS_CREATOR: "cr", - }.get(axis.name, "x") - # Handle value formatting - val = axis.value - if isinstance(val, (list, tuple)): - val = "-".join(str(v) for v in val) +def _genre_name(genre_id: int, content_type: str) -> str: + genre_map = movie_genres if content_type == "movie" else series_genres + return genre_map.get(genre_id, "Movies" if content_type == "movie" else "Series") - parts.append(f"{role_pfx}:{axis_pfx}{val}") - return ".".join(parts) +def _country_adjective(code: str) -> str | None: + adjs = 
COUNTRY_ADJECTIVES.get(code, []) + return random.choice(adjs) if adjs else None -class RowDefinition(BaseModel): - """Defines a dynamic catalog row.""" +def _keyword_display(name: str) -> str: + return name.strip().replace("-", " ").replace("_", " ").title() - title: str - id: str - axes: list[RowAxis] = [] - explanation: str | None = None - expansion_strategy: str | None = None - @property - def is_valid(self) -> bool: - return len(self.axes) > 0 +def _runtime_modifier(bucket: str) -> str | None: + return {"short": "Short & Sweet", "long": "Epic"}.get(bucket) -class LLMRowTheme(BaseModel): - """Schema for structured LLM output - a single themed catalog row.""" +def _pick(items: list, start: int, end: int, exclude: set | None = None) -> Any | None: + """Pick a random item from items[start:end], excluding IDs in `exclude`.""" + pool = items[start:end] + if exclude: + pool = [x for x in pool if x[0] not in exclude] + return random.choice(pool) if pool else None - title: str = Field(description="Creative, short title for the collection (2-5 words)") - genres: list[int] = Field(description="List of valid TMDB genre IDs") - keywords: list[str] = Field(default_factory=list, description="Specific TMDB keyword names") - country: str | None = Field(default=None, description="ISO 3166-1 country code or null") +def _build_core_row( + genres: list[tuple[int, float]], + keywords: list[tuple[int, float]], + runtimes: list[tuple[str, float]], + keyword_names: dict[int, str], + content_type: str, + used_genres: set[int], + used_keywords: set[int], +) -> tuple[list[tuple[str, str, Any]], str] | None: + axes: list[tuple[str, str, Any]] = [] + title_parts: list[str] = [] + + g = _pick(genres, 0, GOLD_END, used_genres) + if not g: + return None + axes.append((ROLE_ANCHOR, AXIS_GENRE, g[0])) + title_parts.append(_genre_name(g[0], content_type)) + used_genres.add(g[0]) + + for _ in range(random.randint(1, 2)): + k = _pick(keywords, 0, GOLD_END, used_keywords) + if k and k[0] in 
keyword_names: + axes.append((ROLE_FLAVOR, AXIS_KEYWORD, k[0])) + title_parts.append(_keyword_display(keyword_names[k[0]])) + used_keywords.add(k[0]) + + if runtimes: + rt = random.choice(runtimes[:2]) + axes.append((ROLE_FALLBACK, AXIS_RUNTIME, rt[0])) + mod = _runtime_modifier(rt[0]) + if mod: + title_parts.insert(0, mod) + + return (axes, " ".join(title_parts)) + + +def _build_blend_row( + genres: list[tuple[int, float]], + countries: list[tuple[str, float]], + content_type: str, + used_genres: set[int], +) -> tuple[list[tuple[str, str, Any]], str] | None: + axes: list[tuple[str, str, Any]] = [] + title_parts: list[str] = [] + + g = _pick(genres, 0, GOLD_END, used_genres) + if not g: + return None + axes.append((ROLE_ANCHOR, AXIS_GENRE, g[0])) + title_parts.append(_genre_name(g[0], content_type)) + used_genres.add(g[0]) + + use_country = random.choice([True, False]) + if use_country and countries: + c = _pick(countries, 0, SILVER_END) + if c: + axes.append((ROLE_FLAVOR, AXIS_COUNTRY, c[0])) + adj = _country_adjective(c[0]) + if adj: + title_parts.insert(0, adj) + else: + other = [gx for gx in genres if gx[0] != g[0]] + sg = _pick(other, 0, SILVER_END) if other else None + if sg: + axes.append((ROLE_FLAVOR, AXIS_GENRE, sg[0])) + title_parts.append(_genre_name(sg[0], content_type)) + + return (axes, " ".join(title_parts)) + + +def _build_rising_row( + genres: list[tuple[int, float]], + keywords: list[tuple[int, float]], + countries: list[tuple[str, float]], + keyword_names: dict[int, str], + content_type: str, + used_genres: set[int], + used_keywords: set[int], +) -> tuple[list[tuple[str, str, Any]], str] | None: + axes: list[tuple[str, str, Any]] = [] + title_parts: list[str] = [] + + k = _pick(keywords, SILVER_START, SILVER_END, used_keywords) + if not k or k[0] not in keyword_names: + return None + axes.append((ROLE_ANCHOR, AXIS_KEYWORD, k[0])) + title_parts.append(_keyword_display(keyword_names[k[0]])) + used_keywords.add(k[0]) -class RowComponents(BaseModel): 
- """Internal structure for building a row.""" + g = _pick(genres, SILVER_START, SILVER_END, used_genres) + if g: + axes.append((ROLE_FLAVOR, AXIS_GENRE, g[0])) + title_parts.append(_genre_name(g[0], content_type)) - axes: list[RowAxis] = [] - explanation: str | None = None + if countries: + c = _pick(countries, 0, SILVER_END) + if c: + axes.append((ROLE_FALLBACK, AXIS_COUNTRY, c[0])) - # For title generation - prompt_parts: list[str] = [] - fallback_parts: list[str] = [] + return (axes, " ".join(title_parts)) - def build_prompt(self) -> str: - """Build Gemini prompt from parts.""" - return " + ".join(self.prompt_parts) - def build_fallback(self) -> str: - """Build fallback title from parts.""" - return " ".join(self.fallback_parts) +def build_fallback_rows( + genres: list[tuple[int, float]], + keywords: list[tuple[int, float]], + countries: list[tuple[str, float]], + runtimes: list[tuple[str, float]], + keyword_names: dict[int, str], + content_type: str, +) -> list[tuple[list[tuple[str, str, Any]], str]]: + """Build up to 3 rows as (axes, fallback_title) tuples.""" + rows: list[tuple[list[tuple[str, str, Any]], str]] = [] + used_genres: set[int] = set() + used_keywords: set[int] = set() - def to_dict(self) -> dict[str, Any]: - """Convert to dict for row building.""" - return { - "axes": self.axes, - "explanation": self.explanation, - } + r1 = _build_core_row(genres, keywords, runtimes, keyword_names, content_type, used_genres, used_keywords) + if r1: + rows.append(r1) + r2 = _build_blend_row(genres, countries, content_type, used_genres) + if r2: + rows.append(r2) -class ExtractedFeatures: - """Container for all extracted profile features with keyword names resolved.""" + r3 = _build_rising_row(genres, keywords, countries, keyword_names, content_type, used_genres, used_keywords) + if r3: + rows.append(r3) - def __init__( - self, - genres: list[tuple[int, float]], - keywords: list[tuple[int, float]], - countries: list[tuple[str, float]], - runtimes: list[tuple[str, 
float]], - creators: list[tuple[int, float]], - keyword_names: dict[int, str], - content_type: str, - ): - self.genres = genres - self.keywords = keywords - self.countries = countries - self.runtimes = runtimes - self.creators = creators - self.keyword_names = keyword_names - self.content_type = content_type - - def get_keyword_name(self, keyword_id: int) -> str | None: - return self.keyword_names.get(keyword_id) - - def get_genre_name(self, genre_id: int) -> str: - return get_genre_name(genre_id, self.content_type) - - -class RowBuilder: - """Builds a single row by sampling from axes with specific roles.""" - - def __init__(self, features: ExtractedFeatures): - self.features = features - self.components = RowComponents() - self.used_axes: set[str] = set() - - def add_axis(self, name: str, value: Any, role: AxisRole, weight: float = 1.0) -> "RowBuilder": - """Add an axis with a specific role and weight.""" - axis = RowAxis(name=name, value=value, role=role, weight=weight) - self.components.axes.append(axis) - - # Build prompt and fallback title parts - display_val = self._get_display_value(name, value) - if display_val: - prefix = "" - if role == AxisRole.ANCHOR: - prefix = "Anchor: " - elif role == AxisRole.FLAVOR: - prefix = "Flavor: " - - self.components.prompt_parts.append(f"{prefix}{name.title()}: {display_val}") - - # For fallback title, we prioritize Anchor and Flavor - if role in (AxisRole.ANCHOR, AxisRole.FLAVOR): - if name == AXIS_COUNTRY: - self.components.fallback_parts.insert(0, display_val) - else: - self.components.fallback_parts.append(display_val) - - self.used_axes.add(f"{name}:{value}") - return self - - def _get_display_value(self, name: str, value: Any) -> str | None: - """Get human-readable value for an axis.""" - if name == AXIS_GENRE: - return self.features.get_genre_name(value) - if name == AXIS_KEYWORD: - return normalize_keyword(self.features.get_keyword_name(value) or "") - if name == AXIS_COUNTRY: - return get_country_adjective(value) - 
if name == AXIS_RUNTIME: - return runtime_to_modifier(value) - return str(value) - - def build(self) -> RowComponents | None: - """Build and return the row components if valid (has at least one anchor).""" - has_anchor = any(a.role == AxisRole.ANCHOR for a in self.components.axes) - if has_anchor: - return self.components - return None + return rows[:3] class RowGeneratorService: - """Generates dynamic, personalized row definitions from a User Taste Profile.""" + """Generates dynamic, personalized row definitions from a taste profile.""" def __init__(self, tmdb_service: TMDBService | None = None): self.tmdb_service = tmdb_service or get_tmdb_service() async def generate_rows( - self, profile: TasteProfile, content_type: str = "movie", api_key: str | None = None + self, + profile: TasteProfile, + content_type: str = "movie", + api_key: str | None = None, ) -> list[RowDefinition]: - """ - Generate exactly 3 personalized catalog rows. - If api_key is provided, uses LLM to generate creative themes. - Otherwise uses tiered sampling system. - - Returns: - List of RowDefinition - """ - # 1. Extract all features from profile - features = await self._extract_features(profile, content_type) - - # 2. 
Try LLM generation if key is present + """Generate up to 3 personalized catalog rows.""" + genres = profile.get_top_genres(limit=5) + keywords = profile.get_top_keywords(limit=10) + countries = profile.get_top_countries(limit=2) + runtimes = sorted(profile.runtime_bucket_scores.items(), key=lambda x: x[1], reverse=True) + + keyword_names = await self._resolve_keyword_names([kid for kid, _ in keywords]) + if api_key: try: - llm_rows = await self._generate_rows_with_llm(profile, features, content_type, api_key) + llm_rows = await self._generate_with_llm( + profile, genres, keywords, keyword_names, content_type, api_key + ) if llm_rows: logger.info(f"Generated {len(llm_rows)} LLM-driven rows for {content_type}") return llm_rows except Exception as e: - logger.warning(f"LLM row generation failed, falling back to tiered sampling: {e}") - - # 3. Fallback to Tiered Sampling - rows_data = [] - used_genres = set() - used_keywords = set() - - # Row 1: The Core (Strongest matches) - core_row = self._build_core_row(features, exclude_genres=used_genres, exclude_keywords=used_keywords) - if core_row: - rows_data.append(core_row) - self._update_used_axes(core_row, used_genres, used_keywords) - - # Row 2: The Blend (Mixing themes) - blend_row = self._build_blend_row(features, exclude_genres=used_genres, exclude_keywords=used_keywords) - if blend_row: - rows_data.append(blend_row) - self._update_used_axes(blend_row, used_genres, used_keywords) - - # Row 3: The Rising Star (Exploration) - rising_row = self._build_rising_star_row(features, exclude_genres=used_genres, exclude_keywords=used_keywords) - if rising_row: - rows_data.append(rising_row) - - # 4. 
Generate titles via server's default Gemini model (gemma) - final_rows = await self._generate_titles(rows_data[:3]) - - logger.info(f"Generated {len(final_rows)} dynamic rows (Tiered Sampling) for {content_type}") - return final_rows - - def _update_used_axes(self, row: RowComponents, used_genres: set, used_keywords: set): - """Track used genres and keywords to ensure row diversity.""" - for axis in row.axes: - if axis.name == AXIS_GENRE: - used_genres.add(axis.value) - elif axis.name == AXIS_KEYWORD: - used_keywords.add(axis.value) - - async def _extract_features(self, profile: TasteProfile, content_type: str) -> ExtractedFeatures: - """Extract all features from profile and resolve keyword names.""" - # Get raw features - genres = profile.get_top_genres(limit=5) - keywords = profile.get_top_keywords(limit=10) - countries = profile.get_top_countries(limit=2) - runtimes = sorted(profile.runtime_bucket_scores.items(), key=lambda x: x[1], reverse=True) - creators = profile.get_top_creators(limit=5) + logger.warning(f"LLM row generation failed, using fallback: {e}") - # Fetch keyword names in parallel - keyword_ids = [k_id for k_id, _ in keywords] - keyword_names_raw = await asyncio.gather( - *[self._get_keyword_name(kid) for kid in keyword_ids], + rows = build_fallback_rows(genres, keywords, countries, runtimes, keyword_names, content_type) + titled = await self._generate_titles(rows) + logger.info(f"Generated {len(titled)} rows (fallback) for {content_type}") + return titled + + # --- Title generation via Gemini --- + + async def _generate_titles(self, rows: list[tuple[list[tuple[str, str, Any]], str]]) -> list[RowDefinition]: + if not rows: + return [] + + prompts = [fallback for _, fallback in rows] + results = await asyncio.gather( + *[gemini_service.generate_content_async(p) for p in prompts], return_exceptions=True, ) - keyword_names = { - kid: name for kid, name in zip(keyword_ids, keyword_names_raw) if name and not isinstance(name, Exception) - } - - return 
ExtractedFeatures( - genres=genres, - keywords=keywords, - countries=countries, - runtimes=runtimes, - creators=creators, - keyword_names=keyword_names, - content_type=content_type, - ) - async def _get_keyword_name(self, keyword_id: int) -> str | None: - """Fetch keyword name from TMDB.""" - try: - data = await self.tmdb_service.get_keyword_details(keyword_id) - return data.get("name") - except Exception: - return None + final = [] + for i, (axes, fallback) in enumerate(rows): + result = results[i] + title = result.strip() if isinstance(result, str) else fallback + final.append(RowDefinition(title=title, id=build_row_id(axes))) + return final + + # --- LLM-based generation --- - def _build_core_row( + async def _generate_with_llm( self, - features: ExtractedFeatures, - exclude_genres: set[int] | None = None, - exclude_keywords: set[int] | None = None, - ) -> RowComponents | None: - """ - Build 'The Core' row: - Anchor: GENRE (Gold) - Flavor: 1-2 KEYWORDS (Gold) - Fallback: RUNTIME (Gold/Silver) - """ - exclude_genres = exclude_genres or set() - exclude_keywords = exclude_keywords or set() - builder = RowBuilder(features) - - # 1. 
Anchor: Genre - available_genres = [g for g in features.genres if g[0] not in exclude_genres] - genres = sample_from_gold(available_genres, 1) if available_genres else sample_from_gold(features.genres, 1) - if not genres: - return None - builder.add_axis(AXIS_GENRE, genres[0][0], AxisRole.ANCHOR, 1.0) + profile: TasteProfile, + genres: list[tuple[int, float]], + keywords: list[tuple[int, float]], + keyword_names: dict[int, str], + content_type: str, + api_key: str, + ) -> list[RowDefinition] | None: + genre_map = movie_genres if content_type == "movie" else series_genres + valid_genres = ", ".join(f"{name} (ID: {gid})" for gid, name in genre_map.items()) + + # Build profile context from actual data + top_genre_names = [genre_map.get(gid, f"ID:{gid}") for gid, _ in genres[:5]] + profile_keywords = [name for kid, _ in keywords[:12] if (name := keyword_names.get(kid))] + top_countries = profile.get_top_countries(limit=2) + country_list = [c for c, _ in top_countries] if top_countries else [] + + profile_context = f"Top genres: {', '.join(top_genre_names)}." + if profile_keywords: + profile_context += f" Themes they enjoy: {', '.join(profile_keywords)}." + if country_list: + profile_context += f" Preferred countries: {', '.join(country_list)}." + + keyword_hint = "You can suggest themes from the user's preferences or new ones for discovery." + + prompt = ( + "Based on the user's taste profile below, generate exactly 3 streaming " + f"collections for {content_type}. " + "Use genres (required), keywords, and country when relevant.\n\n" + f"User Profile:\n{profile_context}\n\n" + "Generate 3 rows:\n" + "1. THE CORE — strongest match to their taste\n" + "2. MIXED PREFERENCES — blend with variety\n" + "3. 
RISING STAR — discovery, adjacent to their taste\n\n" + f"Genres: use ONLY these TMDB Genre IDs: {valid_genres}\n" + f"Keywords: {keyword_hint}\n" + "Country: ISO 3166-1 code or null.\n" + "Each row: title (2-5 words), genres (list of IDs), " + "keywords (list of strings), country (string or null).\n" + "Output a JSON array of 3 objects." + ) - # 2. Flavor: 1-2 Keywords - available_keywords = [k for k in features.keywords if k[0] not in exclude_keywords] - keywords = sample_from_gold(available_keywords, random.randint(1, 2)) if available_keywords else [] - for k_id, _ in keywords: - builder.add_axis(AXIS_KEYWORD, k_id, AxisRole.FLAVOR, 0.7) + data = await gemini_service.generate_structured_async( + prompt=prompt, + response_schema=list[LLMRowTheme], + system_instruction=( + "You are a creative film curator. Design 3 catalog rows from the user's interest summary. " + "Row 1: strong match. Row 2: blend + variety. Row 3: discovery. " + "Use genres, keywords, and country. Output valid JSON only." + ), + api_key=api_key, + ) - # 3. Fallback: Runtime - if features.runtimes: - runtime = random.choice(features.runtimes[:2]) - builder.add_axis(AXIS_RUNTIME, runtime[0], AxisRole.FALLBACK, 0.3) + if not data or not isinstance(data, list): + return None - row = builder.build() - if row: - row.explanation = "The Core: Based on your absolute favorite genres and recurring themes." - return row + profile_kw_map = {name.lower(): kid for kid, name in keyword_names.items()} + final = [] - def _build_blend_row( - self, - features: ExtractedFeatures, - exclude_genres: set[int] | None = None, - exclude_keywords: set[int] | None = None, - ) -> RowComponents | None: - """ - Build 'The Blend' row: - Anchor: GENRE (Gold) - Flavor: COUNTRY or secondary GENRE (Gold/Silver) - """ - exclude_genres = exclude_genres or set() - builder = RowBuilder(features) - - # 1. 
Anchor: Genre - available_genres = [g for g in features.genres if g[0] not in exclude_genres] - genres = sample_from_gold(available_genres, 1) if available_genres else sample_from_gold(features.genres, 1) - if not genres: - return None - builder.add_axis(AXIS_GENRE, genres[0][0], AxisRole.ANCHOR, 1.0) - - # 2. Flavor: Country or Secondary Genre - flavor_type = random.choice([AXIS_COUNTRY, AXIS_GENRE]) - - if flavor_type == AXIS_COUNTRY and features.countries: - country = sample_from_gold_silver(features.countries, 1) - builder.add_axis(AXIS_COUNTRY, country[0][0], AxisRole.FLAVOR, 0.7) - elif flavor_type == AXIS_GENRE: - other_genres = [g for g in features.genres if g[0] != genres[0][0]] - if other_genres: - sec_genre = sample_from_gold_silver(other_genres, 1) - builder.add_axis(AXIS_GENRE, sec_genre[0][0], AxisRole.FLAVOR, 0.7) - - row = builder.build() - if row: - row.explanation = "The Blend: Mixing your top genres with international flavor or secondary interests." - return row - - def _build_rising_star_row( - self, - features: ExtractedFeatures, - exclude_genres: set[int] | None = None, - exclude_keywords: set[int] | None = None, - ) -> RowComponents | None: - """ - Build 'The Rising Star' row: - Anchor: recent KEYWORD (Silver) - Flavor: GENRE (Silver) - Fallback: COUNTRY (Gold/Silver) - """ - exclude_genres = exclude_genres or set() - exclude_keywords = exclude_keywords or set() - builder = RowBuilder(features) - - # 1. 
Anchor: Recent Keyword (Sampling from Silver to promote exploration) - available_keywords = [k for k in features.keywords if k[0] not in exclude_keywords] - keywords = sample_from_silver(available_keywords, 1) if available_keywords else [] - if keywords: - builder.add_axis(AXIS_KEYWORD, keywords[0][0], AxisRole.ANCHOR, 1.0) - - # If we couldn't add an anchor, this row fails - if not builder.components.axes: - return None + for item in data: + if isinstance(item, dict): + title, genre_ids, kw_names, country = ( + item.get("title", "Recommended"), + item.get("genres", []), + item.get("keywords", []), + item.get("country"), + ) + else: + title, genre_ids, kw_names, country = item.title, item.genres, item.keywords, item.country - # 2. Flavor: Genre (Silver) - available_genres = [g for g in features.genres if g[0] not in exclude_genres] - genres = sample_from_silver(available_genres, 1) if available_genres else [] - if genres: - builder.add_axis(AXIS_GENRE, genres[0][0], AxisRole.FLAVOR, 0.7) - - # 3. Fallback: Country - if features.countries: - country = sample_from_gold_silver(features.countries, 1) - builder.add_axis(AXIS_COUNTRY, country[0][0], AxisRole.FALLBACK, 0.3) - - row = builder.build() - if row: - row.explanation = "The Rising Star: Exploring emerging interests and newer themes in your history." - return row - - def _build_signature_rows(self, features: ExtractedFeatures) -> list[RowComponents]: - """Generate dynamic signature recipes from user history.""" - signature_rows = [] - - # 1. Top genre × dominant keyword - if features.genres and features.keywords: - builder = RowBuilder(features) - builder.add_axis(AXIS_GENRE, features.genres[0][0], AxisRole.ANCHOR, 1.0) - builder.add_axis(AXIS_KEYWORD, features.keywords[0][0], AxisRole.FLAVOR, 0.7) - row = builder.build() - if row: - row.explanation = "Signature: Your #1 genre paired with your most frequent theme." - signature_rows.append(row) - - # 2. 
Top genre × preferred runtime - if features.genres and features.runtimes: - builder = RowBuilder(features) - builder.add_axis(AXIS_GENRE, features.genres[0][0], AxisRole.ANCHOR, 1.0) - builder.add_axis(AXIS_RUNTIME, features.runtimes[0][0], AxisRole.FLAVOR, 0.7) - row = builder.build() - if row: - row.explanation = "Signature: Favorite genre fit for your preferred watch duration." - signature_rows.append(row) - - return signature_rows - - async def _generate_titles(self, rows_data: list[RowComponents]) -> list[RowDefinition]: - """Generate titles for tiered sampling rows via server's default Gemini model.""" - if not rows_data: - return [] + axes: list[tuple[str, str, Any]] = [] + for gid in genre_ids: + if int(gid) in genre_map: + axes.append((ROLE_ANCHOR, AXIS_GENRE, int(gid))) - # Build prompts and fire Gemini requests (uses server key + default model) - prompts = [row.build_prompt() for row in rows_data] - gemini_tasks = [gemini_service.generate_content_async(p) for p in prompts] - results = await asyncio.gather(*gemini_tasks, return_exceptions=True) + for kw_name in kw_names: + kid = await self._resolve_keyword_to_id(kw_name, profile_kw_map) + if kid is not None: + axes.append((ROLE_FLAVOR, AXIS_KEYWORD, kid)) - final_rows = [] - for i, row in enumerate(rows_data): - result = results[i] + if country: + axes.append((ROLE_FLAVOR, AXIS_COUNTRY, country)) - # Determine title - if isinstance(result, Exception): - logger.warning(f"Gemini failed for row {i}: {result}") - title = row.build_fallback() - elif result: - title = result.strip() - else: - title = row.build_fallback() + if axes: + final.append(RowDefinition(title=title, id=build_row_id(axes))) - # Build the row ID - row_id = build_row_id(row.axes) + return final if final else None - final_rows.append( - RowDefinition( - title=title, - id=row_id, - **row.to_dict(), - ) - ) + # --- Helpers --- - return final_rows + async def _resolve_keyword_names(self, keyword_ids: list[int]) -> dict[int, str]: + results = 
await asyncio.gather( + *[self._get_keyword_name(kid) for kid in keyword_ids], + return_exceptions=True, + ) + return {kid: name for kid, name in zip(keyword_ids, results) if isinstance(name, str) and name} + + async def _get_keyword_name(self, keyword_id: int) -> str | None: + try: + data = await self.tmdb_service.get_keyword_details(keyword_id) + return data.get("name") + except Exception: + return None async def _resolve_keyword_to_id(self, kw_name: str, profile_kw_map: dict[str, int]) -> int | None: - """Resolve a keyword name to TMDB ID: profile first, then TMDB search (for discovery).""" kw_lower = str(kw_name).strip().lower() if not kw_lower: return None @@ -562,97 +406,3 @@ async def _resolve_keyword_to_id(self, kw_name: str, profile_kw_map: dict[str, i except Exception: pass return None - - async def _generate_rows_with_llm( - self, - profile: TasteProfile, - features: ExtractedFeatures, - content_type: str, - api_key: str, - ) -> list[RowDefinition] | None: - """Generate rows from the user's interest summary; balance personalization with discovery.""" - try: - summary = profile.interest_summary or "No summary available." - - current_genre_map = movie_genres if content_type == "movie" else series_genres - valid_genre_list = ", ".join([f"{name} (ID: {gid})" for gid, name in current_genre_map.items()]) - - profile_keywords = [name for k_id, _ in features.keywords[:12] if (name := features.get_keyword_name(k_id))] - keyword_hint = ( - ( - f"Themes they already like (you can use these): {', '.join(profile_keywords)}. " - if profile_keywords - else "" - ) - + "You can also suggest new themes for discovery—especially for Rising Star—" - "e.g. adjacent genres or topics they might not have tried yet. We will resolve keywords." - ) - - prompt = ( - "Using only the user's interest summary below, generate exactly 3 streaming collections for" - f" {content_type}. 
Use genres (required), keywords, and country when relevant.\n\nInterest" - f" Summary:\n{summary}\n\nGenerate 3 rows in this order:\n1. THE CORE — What they will love" - " most: strongest match to their taste (genres + keywords + country if relevant).\n2. MIXED" - " PREFERENCES — Blend of their tastes with more variety (genres + keywords + country if" - " relevant).\n3. RISING STAR — Discovery: suggest themes they might not have explored yet but" - " would likely enjoy (adjacent to their taste, or natural next step). Use genres + keywords +" - " country; openness to new content here.\n\nRules:\n- Genres: use ONLY these TMDB Genre IDs:" - f" {valid_genre_list}\n- Keywords: {keyword_hint}\n- Country: ISO 3166-1 code (e.g. US, KR, JP)" - " or null when relevant.\n- Each row: title (2-5 words), genres (list of IDs), keywords (list" - " of strings), country (string or null).\n- Output a JSON array of 3 objects." - ) - - data = await gemini_service.generate_structured_async( - prompt=prompt, - response_schema=list[LLMRowTheme], - system_instruction=( - "You are a creative film curator. Design 3 catalog rows from the user's interest summary." - " Row 1 (The Core): strong match. Row 2 (Mixed): blend + variety. Row 3 (Rising Star):" - " discovery—suggest new content they would enjoy, not just more of the same. Use genres," - " keywords, and country. Output valid JSON only." 
- ), - api_key=api_key, - ) - - if not data or not isinstance(data, list): - return None - - final_rows = [] - profile_kw_map = {name.lower(): kid for kid, name in features.keyword_names.items()} - - for item in data: - if isinstance(item, dict): - title = item.get("title", "Recommended") - genre_ids = item.get("genres", []) - kw_names = item.get("keywords", []) - country = item.get("country") - else: - title = item.title - genre_ids = item.genres - kw_names = item.keywords - country = item.country - - builder = RowBuilder(features) - - for gid in genre_ids: - if int(gid) in current_genre_map: - builder.add_axis(AXIS_GENRE, int(gid), AxisRole.ANCHOR) - - for kw_name in kw_names: - kid = await self._resolve_keyword_to_id(kw_name, profile_kw_map) - if kid is not None: - builder.add_axis(AXIS_KEYWORD, kid, AxisRole.FLAVOR) - - if country: - builder.add_axis(AXIS_COUNTRY, country, AxisRole.FLAVOR) - - row_comp = builder.build() - if row_comp and row_comp.axes: - row_id = build_row_id(row_comp.axes) - final_rows.append(RowDefinition(title=title, id=row_id, axes=row_comp.axes)) - - return final_rows if final_rows else None - - except Exception as e: - logger.warning(f"Error in _generate_rows_with_llm: {e}") - return None diff --git a/app/services/rpdb.py b/app/services/rpdb.py deleted file mode 100644 index 8d4e6d6..0000000 --- a/app/services/rpdb.py +++ /dev/null @@ -1,7 +0,0 @@ -class RPDBService: - @staticmethod - def get_poster_url(api_key: str, item_id: str) -> str: - """ - Get poster URL for a specific item by IMDB ID. 
- """ - return f"https://api.ratingposterdb.com/{api_key}/imdb/poster-default/{item_id}.jpg?fallback=true" diff --git a/app/services/simkl.py b/app/services/simkl.py index c9fb1ae..6b4f2e2 100644 --- a/app/services/simkl.py +++ b/app/services/simkl.py @@ -1,10 +1,14 @@ import asyncio +from datetime import datetime from typing import Any +import httpx from cachetools import TTLCache -from httpx import AsyncClient from loguru import logger +from app.core.base_client import BaseClient +from app.models.history import WatchHistory, WatchHistoryItem + def get_popularity(rank: int | None, N: int = 100000, K: int = 100) -> float: if rank is None: @@ -58,9 +62,33 @@ def normalize_simkl_to_tmdb(item: dict[str, Any], mtype: str) -> dict[str, Any]: class SimklService: def __init__(self): self.base_url = "https://api.simkl.com" - self.client = AsyncClient(timeout=10) + self.client = BaseClient(base_url=self.base_url, timeout=10.0, max_retries=3) self._semaphore = asyncio.Semaphore(10) # Max 10 concurrent requests - self._details_cache: TTLCache = TTLCache(maxsize=1000, ttl=3600) # Cache up to 1000 items # 1 hour TTL + self._details_cache: TTLCache = TTLCache(maxsize=1000, ttl=3600) # 1 hour TTL + + async def close(self) -> None: + await self.client.close() + + async def exchange_code(self, code: str, redirect_uri: str, client_id: str, client_secret: str) -> dict[str, Any]: + """Exchange authorization code for an access token.""" + return await self.client.post( + "/oauth/token", + json={ + "code": code, + "client_id": client_id, + "client_secret": client_secret, + "redirect_uri": redirect_uri, + "grant_type": "authorization_code", + }, + ) + + async def get_user_settings(self, access_token: str, client_id: str) -> dict[str, Any]: + """Fetch the authenticated user's profile (used to display 'Connected as ...').""" + headers = { + "Authorization": f"Bearer {access_token}", + "simkl-api-key": client_id, + } + return await self.client.get("/users/settings", headers=headers) async 
def _fetch_with_semaphore(self, coro): """Execute a coroutine with semaphore for rate limiting.""" @@ -68,45 +96,115 @@ async def _fetch_with_semaphore(self, coro): return await coro async def get_trending(self, api_key: str): - url = f"{self.base_url}/movies/trending" - params = {"client_id": api_key} try: - response = await self.client.get(url, params=params, follow_redirects=True) - response.raise_for_status() - json_response = response.json() - return json_response - - except Exception as e: - logger.error(f"Error fetching details from Simkl: {e}") + return await self.client.get("/movies/trending", params={"client_id": api_key}) + except httpx.HTTPStatusError as e: + # 401/403 indicate the user's Simkl token was revoked — let those + # propagate so callers can clear the token and prompt re-auth. + if e.response.status_code in (401, 403): + raise + logger.warning(f"Simkl trending returned {e.response.status_code}: {e}") + return [] + except httpx.RequestError as e: + logger.warning(f"Simkl trending request failed: {e}") return [] async def get_item_details(self, simkl_id, mtype: str, api_key: str) -> dict[str, Any]: """Fetch full item details from Simkl with caching.""" - # Create cache key cache_key = f"{simkl_id}:{mtype}" - # Check cache first if cache_key in self._details_cache: logger.debug(f"Cache hit for Simkl item {simkl_id}") return self._details_cache[cache_key] - # Fetch from API mtype_path = "movies" if mtype == "movie" else "tv" - url = f"{self.base_url}/{mtype_path}/{simkl_id}" - params = {"client_id": api_key, "extended": "full"} try: - response = await self.client.get(url, params=params, follow_redirects=True) - response.raise_for_status() - result = response.json() - - # Store in cache + result = await self.client.get( + f"/{mtype_path}/{simkl_id}", + params={"client_id": api_key, "extended": "full"}, + ) self._details_cache[cache_key] = result return result - - except Exception as e: - logger.error(f"Error fetching details from Simkl: {e}") + 
except httpx.HTTPStatusError as e: + if e.response.status_code in (401, 403): + raise + logger.warning(f"Simkl item details {simkl_id} returned {e.response.status_code}: {e}") + return {} + except httpx.RequestError as e: + logger.warning(f"Simkl item details {simkl_id} request failed: {e}") return {} + async def get_history(self, access_token: str, client_id: str) -> WatchHistory: + """Fetch watch history from Simkl using OAuth access token.""" + headers = { + "Content-Type": "application/json", + "Authorization": f"Bearer {access_token}", + "simkl-api-key": client_id, + } + + results = await asyncio.gather( + self.client.get("/sync/all-items/movies", headers=headers), + self.client.get("/sync/all-items/shows", headers=headers), + return_exceptions=True, + ) + + items: list[WatchHistoryItem] = [] + seen: set[str] = set() + + for idx, result in enumerate(results): + if isinstance(result, Exception): + logger.warning(f"Simkl sync request failed: {result}") + continue + data = result if isinstance(result, dict) else {} + mtype = "movie" if idx == 0 else "series" + entries = data.get("movies", []) if idx == 0 else data.get("shows", []) + + for entry in entries: + media = entry.get("movie") or entry.get("show") or {} + imdb_id = media.get("ids", {}).get("imdb") + if not imdb_id or imdb_id in seen: + continue + seen.add(imdb_id) + + user_rating = entry.get("user_rating") + rating = float(user_rating) if user_rating is not None else None + + last_watched = None + raw_date = entry.get("last_watched_at") + if raw_date: + try: + last_watched = datetime.fromisoformat(str(raw_date).replace("Z", "+00:00")) + except (ValueError, TypeError): + pass + + # Prefer the play/episode count Simkl reports; fall back to 1 + # so 'watched but unmarked-count' items still register as seen. 
+ raw_plays = ( + entry.get("total_plays_count") + if mtype == "movie" + else entry.get("watched_episodes_count") or entry.get("total_plays_count") + ) + try: + watch_count = max(int(raw_plays or 0), 1) + except (TypeError, ValueError): + watch_count = 1 + + items.append( + WatchHistoryItem( + imdb_id=imdb_id, + type=mtype, + name=media.get("title", ""), + rating=rating, + watch_count=watch_count, + completion=1.0, + last_watched=last_watched, + source="simkl", + ) + ) + + logger.info(f"Simkl history: {len(items)} items") + return WatchHistory(items=items, source="simkl") + async def get_recommendations(self, imdb_id: str, mtype: str, api_key: str) -> list[dict[str, Any]]: """Get recommendations for a single item (original method for item-based).""" item_details = await self.get_item_details(imdb_id, mtype, api_key) diff --git a/app/services/stremio/library.py b/app/services/stremio/library.py index f17c890..9c531d7 100644 --- a/app/services/stremio/library.py +++ b/app/services/stremio/library.py @@ -1,12 +1,152 @@ import asyncio +from datetime import datetime from typing import Any from async_lru import alru_cache from loguru import logger +from app.models.history import WatchHistory, WatchHistoryItem +from app.models.library import LibraryCollection, StremioLibraryItem, StremioState from app.services.stremio.client import StremioClient, StremioLikesClient +def stremio_library_to_watch_history(library: LibraryCollection) -> WatchHistory: + """Convert typed LibraryCollection to unified WatchHistory format.""" + items: list[WatchHistoryItem] = [] + seen: set[str] = set() + + category_items = [ + (library.loved, True, False), + (library.liked, False, True), + (library.watched, False, False), + (library.added, False, False), + ] + + for lib_items, is_loved, is_liked in category_items: + for item in lib_items: + imdb_id = item.id + if not imdb_id.startswith("tt") or imdb_id in seen: + continue + seen.add(imdb_id) + + state = item.state + duration = state.duration + 
time_watched = state.timeWatched + times_watched = state.timesWatched + flagged_watched = state.flaggedWatched + + if flagged_watched > 0 or times_watched > 0: + completion = 1.0 + elif duration > 0: + completion = min(time_watched / duration, 1.0) + else: + completion = 0.0 + + rating: float | None = None + if is_loved or item.is_loved: + rating = 9.0 + elif is_liked or item.is_liked: + rating = 7.0 + + last_watched: datetime | None = state.lastWatched + if not last_watched and item.mtime: + try: + last_watched = datetime.fromisoformat(str(item.mtime).replace("Z", "+00:00")) + except (ValueError, TypeError): + pass + + items.append( + WatchHistoryItem( + imdb_id=imdb_id, + type=item.type, + name=item.name, + rating=rating, + watch_count=max(times_watched, 1) if completion > 0 else 0, + completion=completion, + last_watched=last_watched, + source="stremio", + ) + ) + + return WatchHistory(items=items, source="stremio") + + +def _history_item_to_library_item(item: WatchHistoryItem, is_loved: bool, is_liked: bool) -> StremioLibraryItem: + state_kwargs: dict[str, Any] = {} + if item.last_watched: + state_kwargs["lastWatched"] = item.last_watched + state_kwargs["timesWatched"] = max(item.watch_count, 0) + if item.completion >= 1.0: + state_kwargs["flaggedWatched"] = 1 + state_kwargs["timesWatched"] = max(item.watch_count, 1) + elif item.completion > 0: + state_kwargs["duration"] = 6000 + state_kwargs["timeWatched"] = int(6000 * item.completion) + + return StremioLibraryItem( + _id=item.imdb_id, + type=item.type, + name=item.name, + state=StremioState(**state_kwargs), + temp=False, + removed=False, + _is_loved=is_loved, + _is_liked=is_liked, + ) + + +def watch_history_to_library_collection(history: WatchHistory) -> LibraryCollection: + """Convert an external WatchHistory (Trakt/Simkl) into a LibraryCollection. 
+ + Bucketing rules: + loved: rating >= 9, OR no rating + watch_count >= 2 (rewatch as love proxy) + liked: 7 <= rating < 9 + watched: everything else with any completion/watch signal + + Items without IMDb IDs are skipped — downstream code keys on `tt…` / `tmdb:…` + everywhere and dropping them up front avoids fanning empty IDs into TMDB lookups. + """ + loved: list[StremioLibraryItem] = [] + liked: list[StremioLibraryItem] = [] + watched: list[StremioLibraryItem] = [] + seen: set[str] = set() + + for item in history.items: + if not item.imdb_id or item.imdb_id in seen: + continue + seen.add(item.imdb_id) + + rating = item.rating + if rating is not None and rating >= 9.0: + bucket = "loved" + elif rating is not None and rating >= 7.0: + bucket = "liked" + elif rating is None and item.watch_count >= 2: + bucket = "loved" + else: + bucket = "watched" + + is_loved = bucket == "loved" + is_liked = bucket == "liked" + lib_item = _history_item_to_library_item(item, is_loved, is_liked) + + if bucket == "loved": + loved.append(lib_item) + elif bucket == "liked": + liked.append(lib_item) + else: + watched.append(lib_item) + + return LibraryCollection( + loved=loved, + liked=liked, + watched=watched, + added=[], + removed=[], + source=history.source or "stremio", + ) + + class StremioLibraryService: """ Handles fetching and processing of user's Stremio library and likes. @@ -33,7 +173,7 @@ async def get_likes_by_type(self, auth_token: str, media_type: str, status: str logger.exception(f"Failed to fetch {status} {media_type} items: {e}") return [] - async def get_library_items(self, auth_key: str) -> dict[str, list[dict[str, Any]]]: + async def get_library_items(self, auth_key: str) -> LibraryCollection: """ Fetch all library items and categorize them (watched, loved, added, removed). """ @@ -111,24 +251,21 @@ async def get_library_items(self, auth_key: str) -> dict[str, list[dict[str, Any all_raw_items.append(virtual_item) existing_library_ids.add(item_id) - # 3. 
Categorize items - watched: list[dict] = [] - loved: list[dict] = [] - added: list[dict] = [] - removed: list[dict] = [] - liked: list[dict] = [] - - # Create sets for faster lookup - # loved_set = set(loved_movies + loved_series) - # liked_set = set(liked_movies + liked_series) + # 3. Categorize items and convert to typed models at the boundary + watched: list[StremioLibraryItem] = [] + loved: list[StremioLibraryItem] = [] + added: list[StremioLibraryItem] = [] + removed: list[StremioLibraryItem] = [] + liked: list[StremioLibraryItem] = [] for item in all_raw_items: # Basic validation if item.get("type") not in ["movie", "series"]: continue item_id = item.get("_id", "") - if not item_id.startswith("tt") and not item_id.startswith("tmdb:"): - # either imdb id or tmdb id should be there. + # Downstream history/profile pipeline assumes IMDb ids; tmdb-only + # items can't be converted and would be silently dropped later. + if not item_id.startswith("tt"): continue # Check Watched status @@ -141,36 +278,35 @@ async def get_library_items(self, auth_key: str) -> dict[str, list[dict[str, Any is_completion_high = duration > 0 and (time_watched / duration) >= 0.7 is_watched = times_watched > 0 or flagged_watched > 0 or is_completion_high - # if item is loved or liked and but not watched, then also we need to add it - # as users might not have watched it in stremio itself. 
+ # Set enrichment flags before conversion if item_id in loved_set: item["_is_loved"] = True - loved.append(item) - elif item_id in liked_set: item["_is_liked"] = True - liked.append(item) - elif is_watched: - watched.append(item) + # Convert raw dict to typed model + try: + typed_item = StremioLibraryItem(**item) + except Exception: + continue + # Categorize + if item_id in loved_set: + loved.append(typed_item) + elif item_id in liked_set: + liked.append(typed_item) + elif is_watched: + watched.append(typed_item) elif not item.get("removed") and not item.get("temp"): - # item has not removed and item is not temporary meaning item is not - # added by stremio itself on user watch - added.append(item) + added.append(typed_item) else: continue - # elif item.get("removed"): - # # do not do anything with removed items - # # removed.append(item) - # continue - - # 4. Sort watched items by recency - def sort_by_recency(x: dict): - state = x.get("state", {}) or {} + + # 4. Sort by recency + def sort_by_recency(x: StremioLibraryItem): return ( - str(state.get("lastWatched") or str(x.get("_mtime") or "")), - x.get("_mtime") or "", + str(x.state.lastWatched or x.mtime or ""), + x.mtime or "", ) watched.sort(key=sort_by_recency, reverse=True) @@ -185,13 +321,14 @@ def sort_by_recency(x: dict): f" {len(removed)} removed items" ) - return { - "watched": watched, - "loved": loved, - "liked": liked, - "added": added, - "removed": removed, - } + return LibraryCollection( + watched=watched, + loved=loved, + liked=liked, + added=added, + removed=removed, + source="stremio", + ) except Exception as e: logger.exception(f"Error processing library items: {e}") - return {"watched": [], "loved": [], "liked": [], "added": [], "removed": []} + return LibraryCollection() diff --git a/app/services/tmdb/service.py b/app/services/tmdb/service.py index d4b5b11..00e3fe4 100644 --- a/app/services/tmdb/service.py +++ b/app/services/tmdb/service.py @@ -1,6 +1,7 @@ import functools from typing import 
Any +import httpx from async_lru import alru_cache from loguru import logger @@ -46,9 +47,14 @@ async def find_by_imdb_id(self, imdb_id: str) -> tuple[int | None, str | None]: if tmdb_id: return tmdb_id, "tv" + return None, None + except httpx.HTTPStatusError as e: + # Log 404 as warning (item just not in TMDB), 5xx as error. + level = "warning" if e.response.status_code == 404 else "error" + getattr(logger, level)(f"TMDB find_by_imdb_id({imdb_id}) HTTP {e.response.status_code}: {e}") return None, None except Exception as e: - logger.exception(f"Error finding TMDB ID for IMDB {imdb_id}: {e}") + logger.exception(f"Unexpected error finding TMDB ID for IMDB {imdb_id}: {e}") return None, None @alru_cache(maxsize=500, ttl=86400) diff --git a/app/services/token_store.py b/app/services/token_store.py index 2f0a8ae..599c2c5 100644 --- a/app/services/token_store.py +++ b/app/services/token_store.py @@ -128,6 +128,28 @@ async def store_user_data(self, user_id: str, payload: dict[str, Any]) -> str: storage_data["settings"]["tmdb_api_key"] = self.encrypt_token(tmdb_api_key) except Exception as exc: logger.warning(f"Failed to encrypt tmdb_api_key for {redact_token(user_id)}: {exc}") + + # Encrypt trakt tokens if present + if storage_data.get("settings") and isinstance(storage_data["settings"], dict): + for trakt_field in ("trakt_access_token", "trakt_refresh_token"): + value = storage_data["settings"].get(trakt_field) + if value: + try: + if not value.startswith("gAAAAAB"): + storage_data["settings"][trakt_field] = self.encrypt_token(value) + except Exception as exc: + logger.warning(f"Failed to encrypt {trakt_field} for {redact_token(user_id)}: {exc}") + + # Encrypt simkl_access_token if present + if storage_data.get("settings") and isinstance(storage_data["settings"], dict): + simkl_access_token = storage_data["settings"].get("simkl_access_token") + if simkl_access_token: + try: + if not simkl_access_token.startswith("gAAAAAB"): + 
storage_data["settings"]["simkl_access_token"] = self.encrypt_token(simkl_access_token) + except Exception as exc: + logger.warning(f"Failed to encrypt simkl_access_token for {redact_token(user_id)}: {exc}") + json_str = json.dumps(storage_data) if settings.TOKEN_TTL_SECONDS and settings.TOKEN_TTL_SECONDS > 0: @@ -177,14 +199,15 @@ async def _migrate_poster_rating_format_raw(self, token: str, redis_key: str, da needs_save = True # Case 2: Clean up deprecated rpdb_key field if it exists (even if empty/null) - # Remove it since we've migrated to poster_rating or it's no longer needed + # Remove it since we've migrated to poster_rating or it's no longer needed. + # Do not overwrite a valid migrated poster_rating payload. if "rpdb_key" in settings_dict: settings_dict.pop("rpdb_key") - # keep empty poster_rating field for now - settings_dict["poster_rating"] = { - "provider": "rpdb", - "api_key": None, - } + if not settings_dict.get("poster_rating"): + settings_dict["poster_rating"] = { + "provider": "rpdb", + "api_key": None, + } if not needs_save: # Only log if we didn't already log migration logger.info(f"[MIGRATION] Removing deprecated rpdb_key field for {redact_token(token)}") needs_save = True @@ -224,7 +247,10 @@ async def get_user_data(self, token: str) -> dict[str, Any] | None: pass return data - @alru_cache(maxsize=2000, ttl=43200) + # 5-minute TTL: keeps reads cheap under bursty traffic but bounds the window + # in which a deleted token can keep authenticating on a worker that didn't + # observe the local cache invalidation (e.g. multi-worker deployments). + @alru_cache(maxsize=2000, ttl=300) async def _get_user_data_cached(self, token: str) -> dict[str, Any] | None: logger.debug(f"[REDIS] Cache miss. 
Fetching data from redis for {token}") key = self._format_key(token) @@ -294,6 +320,25 @@ async def _get_user_data_cached(self, token: str) -> dict[str, Any] | None: except Exception as e: logger.debug(f"Decryption failed for tmdb_api_key associated with {redact_token(token)}: {e}") + # Decrypt trakt tokens + for trakt_field in ("trakt_access_token", "trakt_refresh_token"): + value = data["settings"].get(trakt_field) + if value: + try: + if value.startswith("gAAAAA"): + data["settings"][trakt_field] = self.decrypt_token(value) + except Exception as e: + logger.debug(f"Decryption failed for {trakt_field} associated with {redact_token(token)}: {e}") + + # Decrypt simkl_access_token + simkl_access_token = data["settings"].get("simkl_access_token") + if simkl_access_token: + try: + if simkl_access_token.startswith("gAAAAA"): + data["settings"]["simkl_access_token"] = self.decrypt_token(simkl_access_token) + except Exception as e: + logger.debug(f"Decryption failed for simkl_access_token associated with {redact_token(token)}: {e}") + return data async def delete_token(self, token: str = None, key: str = None) -> None: diff --git a/app/services/trakt.py b/app/services/trakt.py new file mode 100644 index 0000000..834b7d7 --- /dev/null +++ b/app/services/trakt.py @@ -0,0 +1,179 @@ +import asyncio +from typing import Any + +from loguru import logger + +from app.core.base_client import BaseClient +from app.core.config import settings +from app.models.history import WatchHistory, WatchHistoryItem + + +class TraktService: + """Service for interacting with the Trakt API.""" + + BASE_URL = "https://api.trakt.tv" + + def __init__(self): + self.client = BaseClient(base_url=self.BASE_URL, timeout=15.0, max_retries=3) + + async def close(self) -> None: + await self.client.close() + + def _headers(self, access_token: str) -> dict[str, str]: + return { + "Content-Type": "application/json", + "trakt-api-version": "2", + "trakt-api-key": settings.TRAKT_CLIENT_ID or "", + 
"Authorization": f"Bearer {access_token}", + } + + async def get_user_info(self, access_token: str) -> dict[str, Any]: + """GET /users/me - validate token and get username.""" + return await self.client.get("/users/me", headers=self._headers(access_token)) + + async def exchange_code(self, code: str, redirect_uri: str) -> dict[str, Any]: + """Exchange authorization code for tokens.""" + return await self.client.post( + "/oauth/token", + json={ + "code": code, + "client_id": settings.TRAKT_CLIENT_ID, + "client_secret": settings.TRAKT_CLIENT_SECRET, + "redirect_uri": redirect_uri, + "grant_type": "authorization_code", + }, + ) + + async def refresh_token(self, refresh_token: str, redirect_uri: str) -> dict[str, Any]: + """Refresh expired Trakt access token.""" + return await self.client.post( + "/oauth/token", + json={ + "refresh_token": refresh_token, + "client_id": settings.TRAKT_CLIENT_ID, + "client_secret": settings.TRAKT_CLIENT_SECRET, + "redirect_uri": redirect_uri, + "grant_type": "refresh_token", + }, + ) + + async def get_history(self, access_token: str) -> WatchHistory: + """Fetch watched + rated items, return as WatchHistory.""" + headers = self._headers(access_token) + + # Fetch all 4 endpoints in parallel; BaseClient returns parsed JSON + # and handles retry on 429/5xx internally. 
+ results = await asyncio.gather( + self.client.get("/users/me/watched/movies", headers=headers), + self.client.get("/users/me/watched/shows", headers=headers), + self.client.get("/users/me/ratings/movies", headers=headers), + self.client.get("/users/me/ratings/shows", headers=headers), + return_exceptions=True, + ) + + watched_movies = self._safe_list(results[0], "watched/movies") + watched_shows = self._safe_list(results[1], "watched/shows") + rated_movies = self._safe_list(results[2], "ratings/movies") + rated_shows = self._safe_list(results[3], "ratings/shows") + + # Build rating lookup: imdb_id -> rating (1-10) + ratings: dict[str, float] = {} + for item in rated_movies + rated_shows: + media = item.get("movie") or item.get("show") or {} + imdb_id = media.get("ids", {}).get("imdb") + if imdb_id and item.get("rating"): + ratings[imdb_id] = float(item["rating"]) + + # Convert watched items to WatchHistoryItem + items: list[WatchHistoryItem] = [] + seen_ids: set[str] = set() + + for entry in watched_movies: + movie = entry.get("movie", {}) + imdb_id = movie.get("ids", {}).get("imdb") + if not imdb_id or imdb_id in seen_ids: + continue + seen_ids.add(imdb_id) + items.append( + WatchHistoryItem( + imdb_id=imdb_id, + type="movie", + name=movie.get("title", ""), + rating=ratings.get(imdb_id), + watch_count=entry.get("plays", 1), + completion=1.0, + last_watched=self._parse_date(entry.get("last_watched_at")), + source="trakt", + ) + ) + + for entry in watched_shows: + show = entry.get("show", {}) + imdb_id = show.get("ids", {}).get("imdb") + if not imdb_id or imdb_id in seen_ids: + continue + seen_ids.add(imdb_id) + items.append( + WatchHistoryItem( + imdb_id=imdb_id, + type="series", + name=show.get("title", ""), + rating=ratings.get(imdb_id), + watch_count=entry.get("plays", 1), + completion=1.0, + last_watched=self._parse_date(entry.get("last_watched_at")), + source="trakt", + ) + ) + + # Add rated-but-not-watched items (user rated without watching on Trakt) + for 
item in rated_movies + rated_shows: + media = item.get("movie") or item.get("show") or {} + imdb_id = media.get("ids", {}).get("imdb") + if not imdb_id or imdb_id in seen_ids: + continue + seen_ids.add(imdb_id) + mtype = "movie" if "movie" in item else "series" + items.append( + WatchHistoryItem( + imdb_id=imdb_id, + type=mtype, + name=media.get("title", ""), + rating=float(item.get("rating", 0)), + watch_count=0, + completion=0.0, + last_watched=self._parse_date(item.get("rated_at")), + source="trakt", + ) + ) + + logger.info(f"Trakt history: {len(items)} items ({len(ratings)} rated)") + return WatchHistory(items=items, source="trakt") + + @staticmethod + def _safe_list(result, label: str) -> list: + if isinstance(result, Exception): + logger.warning(f"Trakt {label} request failed: {result}") + return [] + # BaseClient returns dict for JSON objects; Trakt list endpoints return + # arrays which BaseClient parses to list — but its type is annotated as + # dict. Accept either shape defensively. 
+ if isinstance(result, list): + return result + if isinstance(result, dict) and not result: + return [] + return result if isinstance(result, list) else [] + + @staticmethod + def _parse_date(date_str: str | None): + if not date_str: + return None + try: + from datetime import datetime + + return datetime.fromisoformat(date_str.replace("Z", "+00:00")) + except (ValueError, TypeError): + return None + + +trakt_service = TraktService() diff --git a/app/services/translation.py b/app/services/translation.py index c0f0123..5766da2 100644 --- a/app/services/translation.py +++ b/app/services/translation.py @@ -18,6 +18,9 @@ ("fr", "More Like"): "Titres similaires à", ("fr", "More like"): "Titres similaires à", ("fr", "Because you watched"): "Parce que vous avez regardé", + ("fr", "Because you loved"): "Parce que vous avez adoré", + ("de", "Because you watched"): "Weil du angesehen hast", + ("de", "Because you loved"): "Weil du geliebt hast", ("fr", "Genre & Keyword Catalogs"): "Genres et mots-clés", ("fr", "From your favourite Creators"): "De vos créateurs préférés", ("fr", "Based on what you loved"): "D'après vos coups de cœur", @@ -41,7 +44,6 @@ def _normalize_german_formality(text: str) -> str: class TranslationService: - @alru_cache(maxsize=1000, ttl=7 * 24 * 60 * 60) async def translate(self, text: str, target_lang: str | None) -> str: if not text or not target_lang: return text @@ -56,19 +58,24 @@ async def translate(self, text: str, target_lang: str | None) -> str: return _STATIC_TRANSLATIONS[static_key] try: - loop = asyncio.get_running_loop() - - translated = await loop.run_in_executor( - None, lambda: GoogleTranslator(source="auto", target=lang).translate(text) - ) - result = translated if translated else text - if lang == "de": - result = _normalize_german_formality(result) - return result + return await self._translate_cached(text, lang) except Exception as e: + # Fall back to source text on failure but don't cache the fallback — + # otherwise a transient API 
blip poisons the cache for 7 days. logger.exception(f"Translation failed for '{text}' to '{lang}': {e}") return text + @alru_cache(maxsize=1000, ttl=7 * 24 * 60 * 60) + async def _translate_cached(self, text: str, lang: str) -> str: + loop = asyncio.get_running_loop() + translated = await loop.run_in_executor( + None, lambda: GoogleTranslator(source="auto", target=lang).translate(text) + ) + result = translated if translated else text + if lang == "de": + result = _normalize_german_formality(result) + return result + translation_service = TranslationService() diff --git a/app/services/user_cache.py b/app/services/user_cache.py index d1ace43..ce77d90 100644 --- a/app/services/user_cache.py +++ b/app/services/user_cache.py @@ -5,9 +5,10 @@ from loguru import logger -from app.core.constants import CATALOG_KEY, LIBRARY_ITEMS_KEY, PROFILE_KEY, WATCHED_SETS_KEY +from app.core.constants import CATALOG_KEY, LIBRARY_ITEMS_KEY, PROFILE_KEY, USER_CACHE_TTL_SECONDS, WATCHED_SETS_KEY from app.core.security import redact_token -from app.models.taste_profile import TasteProfile +from app.models.library import LibraryCollection +from app.models.profile import TasteProfile from app.services.redis_service import redis_service @@ -39,42 +40,29 @@ def _last_profile_build_key(token: str, content_type: str) -> str: # Library Items Methods - async def get_library_items(self, token: str) -> dict[str, Any] | None: - """ - Get cached library items for a user. - - Args: - token: User token - - Returns: - Library items dictionary, or None if not cached - """ + async def get_library_items(self, token: str) -> LibraryCollection | None: + """Get cached library items for a user.""" key = self._library_items_key(token) cached = await redis_service.get(key) if cached: try: - return json.loads(cached) - except json.JSONDecodeError as e: + data = json.loads(cached) + # Refresh TTL on read so active users' caches stay warm. 
+ await redis_service.expire(key, USER_CACHE_TTL_SECONDS) + return LibraryCollection.model_validate(data) + except (json.JSONDecodeError, Exception) as e: logger.warning(f"Failed to decode cached library items for {redact_token(token)}...: {e}") return None return None - async def set_library_items(self, token: str, library_items: dict[str, Any]) -> None: - """ - Cache library items for a user. - - Args: - token: User token - library_items: Library items dictionary to cache - """ + async def set_library_items(self, token: str, library_items: LibraryCollection) -> None: + """Cache library items for a user.""" key = self._library_items_key(token) - await redis_service.set(key, json.dumps(library_items)) + await redis_service.set(key, library_items.model_dump_json(by_alias=True), USER_CACHE_TTL_SECONDS) logger.debug(f"[{redact_token(token)}...] Cached library items") - # Invalidate all catalog caches when library items are updated - # This ensures catalogs are regenerated with fresh library data await self.invalidate_all_catalogs(token) async def invalidate_library_items(self, token: str) -> None: @@ -106,7 +94,9 @@ async def get_profile(self, token: str, content_type: str) -> TasteProfile | Non if cached: try: - return TasteProfile.model_validate_json(cached) + profile = TasteProfile.model_validate_json(cached) + await redis_service.expire(key, USER_CACHE_TTL_SECONDS) + return profile except (json.JSONDecodeError, ValueError) as e: logger.warning(f"Failed to decode cached profile for {redact_token(token)}.../{content_type}: {e}") return None @@ -123,7 +113,7 @@ async def set_profile(self, token: str, content_type: str, profile: TasteProfile profile: TasteProfile instance to cache """ key = self._profile_key(token, content_type) - await redis_service.set(key, profile.model_dump_json()) + await redis_service.set(key, profile.model_dump_json(), USER_CACHE_TTL_SECONDS) logger.debug(f"[{redact_token(token)}...] 
Cached profile for {content_type}") async def invalidate_profile(self, token: str, content_type: str) -> None: @@ -159,6 +149,7 @@ async def get_watched_sets(self, token: str, content_type: str) -> tuple[set[int data = json.loads(cached) watched_tmdb = set(data.get("watched_tmdb", [])) watched_imdb = set(data.get("watched_imdb", [])) + await redis_service.expire(key, USER_CACHE_TTL_SECONDS) return (watched_tmdb, watched_imdb) except (json.JSONDecodeError, KeyError, TypeError) as e: logger.warning(f"Failed to decode cached watched sets for {redact_token(token)}.../{content_type}: {e}") @@ -187,7 +178,7 @@ async def set_watched_sets( "watched_tmdb": list(watched_tmdb), "watched_imdb": list(watched_imdb), } - await redis_service.set(key, json.dumps(data)) + await redis_service.set(key, json.dumps(data), USER_CACHE_TTL_SECONDS) logger.debug(f"[{redact_token(token)}...] Cached watched sets for {content_type}") async def invalidate_watched_sets(self, token: str, content_type: str) -> None: @@ -229,6 +220,13 @@ async def get_profile_and_watched_sets( # Library Change Detection Methods + @staticmethod + def _extract_item_id(item) -> str: + """Extract item ID from either a typed StremioLibraryItem or a raw dict.""" + if hasattr(item, "id"): + return item.id + return item.get("_id", item.get("id", "")) + async def has_library_changed(self, token: str, content_type: str, library_items: list) -> bool: """ Check if library has changed since last profile build. 
@@ -241,18 +239,16 @@ async def has_library_changed(self, token: str, content_type: str, library_items Returns: True if library has changed, False otherwise """ - # Create hash of current library item IDs - current_ids = [item.get("_id", item.get("id", "")) for item in library_items] + current_ids = [self._extract_item_id(item) for item in library_items] current_hash = hashlib.md5("".join(sorted(current_ids)).encode()).hexdigest() - # Compare with stored hash stored_hash = await redis_service.get(self._library_hash_key(token, content_type)) if stored_hash is None: - # No stored hash, consider it changed return True - return current_hash != stored_hash.decode() if isinstance(stored_hash, bytes) else current_hash != stored_hash + # redis_service is configured with decode_responses=True so stored_hash is always str. + return current_hash != stored_hash async def update_library_hash(self, token: str, content_type: str, library_items: list) -> None: """ @@ -263,15 +259,15 @@ async def update_library_hash(self, token: str, content_type: str, library_items content_type: Content type (movie or series) library_items: Current library items list """ - current_ids = [item.get("_id", item.get("id", "")) for item in library_items] + current_ids = [self._extract_item_id(item) for item in library_items] current_hash = hashlib.md5("".join(sorted(current_ids)).encode()).hexdigest() hash_key = self._library_hash_key(token, content_type) build_time_key = self._last_profile_build_key(token, content_type) # Store hash and build timestamp - await redis_service.set(hash_key, current_hash) - await redis_service.set(build_time_key, str(time.time())) + await redis_service.set(hash_key, current_hash, USER_CACHE_TTL_SECONDS) + await redis_service.set(build_time_key, str(time.time()), USER_CACHE_TTL_SECONDS) logger.debug(f"[{redact_token(token)}...] 
Updated library hash for {content_type}") @@ -291,7 +287,7 @@ async def get_last_profile_build_time(self, token: str, content_type: str) -> in return None try: - return int(float(build_time.decode() if isinstance(build_time, bytes) else build_time)) + return int(float(build_time)) except (ValueError, TypeError): return None diff --git a/app/startup/migration.py b/app/startup/migration.py deleted file mode 100644 index 39f46d9..0000000 --- a/app/startup/migration.py +++ /dev/null @@ -1,239 +0,0 @@ -import base64 -import hashlib -import json -import traceback - -import httpx -import redis.asyncio as redis -from cryptography.fernet import Fernet -from loguru import logger - -from app.core.config import settings -from app.services.token_store import token_store - - -def decrypt_data(enc_json: str): - key_bytes = hashlib.sha256(settings.TOKEN_SALT.encode()).digest() - fernet_key = base64.urlsafe_b64encode(key_bytes) - cipher = Fernet(fernet_key) - if not isinstance(enc_json, str): - return {} - try: - decrypted = cipher.decrypt(enc_json.encode()).decode() - except Exception as exc: - logger.warning(f"Failed to decrypt data: {exc}") - raise exc - return json.loads(decrypted) - - -async def get_auth_key(username: str, password: str): - url = "https://api.strem.io/api/login" - payload = { - "email": username, - "password": password, - "type": "Login", - "facebook": False, - } - async with httpx.AsyncClient(timeout=10.0) as client: - result = await client.post(url, json=payload) - result.raise_for_status() - data = result.json() - auth_key = data.get("result", {}).get("authKey", "") - return auth_key - - -async def get_user_info(auth_key): - url = "https://api.strem.io/api/getUser" - payload = { - "type": "GetUser", - "authKey": auth_key, - } - async with httpx.AsyncClient(timeout=10.0) as client: - response = await client.post(url, json=payload) - response.raise_for_status() - data = response.json() - result = data.get("result", {}) - email = result.get("email") - user_id 
= result.get("_id") - return email, user_id - - -async def get_addons(auth_key: str): - url = "https://api.strem.io/api/addonCollectionGet" - payload = { - "type": "AddonCollectionGet", - "authKey": auth_key, - "update": True, - } - async with httpx.AsyncClient(timeout=10.0) as client: - result = await client.post(url, json=payload) - result.raise_for_status() - data = result.json() - error_payload = data.get("error") - if not error_payload and (data.get("code") and data.get("message")): - error_payload = data - - if error_payload: - message = "Invalid Stremio auth key." - if isinstance(error_payload, dict): - message = error_payload.get("message") or message - elif isinstance(error_payload, str): - message = error_payload or message - logger.warning(f"Addon collection request failed: {error_payload}") - raise ValueError(f"Stremio: {message}") - addons = data.get("result", {}).get("addons", []) - logger.info(f"Found {len(addons)} addons") - return addons - - -async def update_addon_url(auth_key: str, user_id: str): - addons = await get_addons(auth_key) - hostname = settings.HOST_NAME if settings.HOST_NAME.startswith("https") else f"https://{settings.HOST_NAME}" - for addon in addons: - if addon.get("manifest", {}).get("id") == settings.ADDON_ID: - addon["transportUrl"] = f"{hostname}/{user_id}/manifest.json" - - url = "https://api.strem.io/api/addonCollectionSet" - payload = { - "type": "AddonCollectionSet", - "authKey": auth_key, - "addons": addons, - } - - async with httpx.AsyncClient(timeout=10.0) as client: - result = await client.post(url, json=payload) - result.raise_for_status() - logger.info("Updated addon url") - return result.json().get("result", {}).get("success", False) - - -async def decode_old_payloads(encrypted_raw: str): - key_bytes = hashlib.sha256(settings.TOKEN_SALT.encode()).digest() - fernet_key = base64.urlsafe_b64encode(key_bytes) - cipher = Fernet(fernet_key) - decrypted_json = cipher.decrypt(encrypted_raw.encode()).decode("utf-8") - payload 
= json.loads(decrypted_json) - return payload - - -def encrypt_auth_key(auth_key: str) -> str: - # Delegate to TokenStore to keep encryption consistent everywhere - return token_store.encrypt_token(auth_key) - - -def prepare_default_payload(email, user_id): - return { - "email": email, - "user_id": user_id, - "settings": { - "catalogs": [ - {"id": "watchly.rec", "name": "Recommended", "enabled": True}, - {"id": "watchly.loved", "name": "Because You Loved", "enabled": True}, - {"id": "watchly.watched", "name": "Because You Watched", "enabled": True}, - {"id": "watchly.theme", "name": "Genre & Theme Collections", "enabled": True}, - ], - "language": "en", - "rpdb_key": "", - "excluded_movie_genres": [], - "excluded_series_genres": [], - }, - } - - -async def store_payload(client: redis.Redis, email: str, user_id: str, auth_key: str): - payload = prepare_default_payload(email, user_id) - logger.info(f"Storing payload for {user_id}: {payload}") - try: - # encrypt auth_key - if auth_key: - payload["authKey"] = encrypt_auth_key(auth_key) - key = f"{settings.REDIS_TOKEN_KEY}{user_id.strip()}" - await client.set(key, json.dumps(payload)) - except (redis.RedisError, OSError) as exc: - logger.warning(f"Failed to store payload for {key}: {exc}") - - -async def process_migration_key(redis_client: redis.Redis, key: str) -> bool: - try: - try: - data_raw = await redis_client.get(key) - except (redis.RedisError, OSError) as exc: - logger.warning(f"Failed to fetch payload for {key}: {exc}") - return False - - if not data_raw: - logger.warning(f"Failed to fetch payload for {key}: Empty data") - return False - - try: - payload = await decode_old_payloads(data_raw) - except (json.JSONDecodeError, Exception) as exc: - logger.warning(f"Failed to decode payload for key {key}: {exc}") - return False - - if payload.get("username") and payload.get("password"): - auth_key = await get_auth_key(payload["username"], payload["password"]) - elif payload.get("authKey"): - auth_key = 
payload.get("authKey") - else: - logger.warning(f"Failed to migrate {key}") - await redis_client.delete(key) - return False - - email, user_id = await get_user_info(auth_key) - if not email or not user_id: - logger.warning(f"Failed to migrate {key}") - await redis_client.delete(key) - return False - - new_payload = prepare_default_payload(email, user_id) - if auth_key: - new_payload["authKey"] = encrypt_auth_key(auth_key) - - new_key = f"{settings.REDIS_TOKEN_KEY}{user_id.strip()}" - payload_json = json.dumps(new_payload) - - if settings.TOKEN_TTL_SECONDS and settings.TOKEN_TTL_SECONDS > 0: - set_success = await redis_client.set(new_key, payload_json, ex=settings.TOKEN_TTL_SECONDS, nx=True) - if set_success: - logger.info( - f"Stored encrypted credential payload with TTL {settings.TOKEN_TTL_SECONDS} seconds (SETNX)" - ) - else: - set_success = await redis_client.setnx(new_key, payload_json) - if set_success: - logger.info("Stored encrypted credential payload without expiration (SETNX)") - - if not set_success: - logger.info(f"Credential payload for {new_key} already exists, not overwriting.") - - await redis_client.delete(key) - logger.info(f"Migrated {key} to {new_key}") - return True - - except Exception as exc: - await redis_client.delete(key) - traceback.print_exc() - logger.warning(f"Failed to migrate {key}: {exc}") - return False - - -async def migrate_tokens(): - total_tokens = 0 - failed_tokens = 0 - success_tokens = 0 - try: - redis_client = await token_store._get_client() - except (redis.RedisError, OSError) as exc: - logger.warning(f"Failed to connect to Redis: {exc}") - return - - pattern = f"{settings.REDIS_TOKEN_KEY}*" - async for key in redis_client.scan_iter(match=pattern): - total_tokens += 1 - if await process_migration_key(redis_client, key): - success_tokens += 1 - else: - failed_tokens += 1 - - logger.info(f"[STATS] Total: {total_tokens}, Failed: {failed_tokens}, Success: {success_tokens}") diff --git a/app/static/js/main.js 
b/app/static/js/main.js index d79b92a..0dbcdfc 100644 --- a/app/static/js/main.js +++ b/app/static/js/main.js @@ -1,14 +1,14 @@ // Main entry point - initializes all modules -import { defaultCatalogs } from './constants.js'; -import { showToast, initializeFooter, initializeKofi } from './modules/ui.js'; +import { createAppState, resetAppState } from './state.js'; +import { initializeFooter, initializeKofi } from './modules/ui.js'; import { initializeNavigation, switchSection, lockNavigationForLoggedOut, initializeMobileNav, updateMobileLayout, unlockNavigation } from './modules/navigation.js'; import { initializeAuth, setStremioLoggedOutState } from './modules/auth.js'; -import { initializeCatalogList, renderCatalogList, getCatalogs, setCatalogs } from './modules/catalog.js'; -import { initializeForm, clearErrors } from './modules/form.js'; +import { initializeCatalogList, renderCatalogList } from './modules/catalog.js'; +import { initializeForm, clearErrors, refreshYearSlider } from './modules/form.js'; +import { initializeAccountsUI } from './modules/accounts.js'; -// Initialize catalogs state -let catalogsState = JSON.parse(JSON.stringify(defaultCatalogs)); +const appState = createAppState(); // DOM Elements const configForm = document.getElementById('configForm'); @@ -22,6 +22,7 @@ const emailInput = document.getElementById('emailInput'); const passwordInput = document.getElementById('passwordInput'); const emailPwdContinueBtn = document.getElementById('emailPwdContinueBtn'); const languageSelect = document.getElementById('languageSelect'); +const accountsNextBtn = document.getElementById('accountsNextBtn'); const configNextBtn = document.getElementById('configNextBtn'); const catalogsNextBtn = document.getElementById('catalogsNextBtn'); const successResetBtn = document.getElementById('successResetBtn'); @@ -50,26 +51,19 @@ const mainEl = document.querySelector('main'); // Reset App Function function resetApp() { if (configForm) configForm.reset(); + 
resetAppState(appState); clearErrors(); - // Reset Navigation is now Back to Welcome - switchSection('welcome'); - - // Lock Navs - Object.keys(navItems).forEach(key => { - if (key !== 'login' && key !== 'welcome') { - if (navItems[key]) navItems[key].classList.add('disabled'); - } - }); - // Reset Stremio State setStremioLoggedOutState(); // Reset catalogs - catalogsState = JSON.parse(JSON.stringify(defaultCatalogs)); - setCatalogs(catalogsState); renderCatalogList(); + // Reset Navigation is now Back to Welcome + switchSection(appState.ui.currentSection); + lockNavigationForLoggedOut(); + // Show Form if (configForm) configForm.classList.remove('hidden'); if (sections.success) sections.success.classList.add('hidden'); @@ -95,7 +89,7 @@ function initializeWelcomeFlow() { // Initialize everything document.addEventListener('DOMContentLoaded', () => { // Start at Welcome - switchSection('welcome'); + switchSection(appState.ui.currentSection); initializeWelcomeFlow(); // Initialize all modules @@ -103,19 +97,28 @@ document.addEventListener('DOMContentLoaded', () => { navItems, sections, mainEl - }); + }, appState); // By default, ensure logged-out users see only Welcome/Login lockNavigationForLoggedOut(); - // Initialize catalog management - set catalogs first - setCatalogs(catalogsState); - initializeCatalogList( - { catalogList }, + initializeAccountsUI({ switchSection }); + + initializeCatalogList({ catalogList }, appState); + + // Initialize form handling + initializeForm( { - catalogs: catalogsState, - renderCatalogList - } + configForm, + submitBtn, + emailInput, + passwordInput, + languageSelect, + movieGenreList, + seriesGenreList + }, + appState, + { resetApp } ); // Initialize authentication @@ -128,27 +131,14 @@ document.addEventListener('DOMContentLoaded', () => { emailPwdContinueBtn, languageSelect }, + appState, { - getCatalogs, renderCatalogList, - resetApp - } - ); - - // Initialize form handling - initializeForm( - { - configForm, - submitBtn, - 
emailInput, - passwordInput, - languageSelect, - movieGenreList, - seriesGenreList - }, - { - getCatalogs, - resetApp + resetApp, + switchSection, + unlockNavigation, + lockNavigationForLoggedOut, + updateYearSlider: refreshYearSlider } ); @@ -165,6 +155,9 @@ document.addEventListener('DOMContentLoaded', () => { window.addEventListener('orientationchange', updateMobileLayout); // Next Buttons + if (accountsNextBtn) accountsNextBtn.addEventListener('click', () => { + if (!accountsNextBtn.disabled) switchSection('config'); + }); if (configNextBtn) configNextBtn.addEventListener('click', () => switchSection('catalogs')); if (catalogsNextBtn) catalogsNextBtn.addEventListener('click', () => switchSection('install')); @@ -173,8 +166,3 @@ document.addEventListener('DOMContentLoaded', () => { if (resetBtn) resetBtn.addEventListener('click', resetApp); if (successResetBtn) successResetBtn.addEventListener('click', resetApp); }); - -// Make resetApp available globally for auth module -window.resetApp = resetApp; -window.switchSection = switchSection; -window.unlockNavigation = unlockNavigation; diff --git a/app/static/js/modules/accounts.js b/app/static/js/modules/accounts.js new file mode 100644 index 0000000..5280228 --- /dev/null +++ b/app/static/js/modules/accounts.js @@ -0,0 +1,137 @@ +// Accounts page state + Watch History Source segmented control on Configure. +// +// Each provider card has two views — disconnected (login UI) and connected +// (status + Disconnect) — toggled via setProviderConnected. The Configure +// page's source picker is always interactive: clicking a provider that isn't +// connected jumps the user to the matching card in Accounts instead of +// silently doing nothing. 
+ +import { showToast } from './ui.js'; + +const ACTIVE_CLASSES = ['bg-white/10', 'text-white', 'shadow-sm']; +const ACTIVE_BORDER_CLASS = 'border-white/20'; +const INACTIVE_CLASSES = ['text-slate-400', 'hover:text-white', 'hover:bg-white/5']; +const INACTIVE_BORDER_CLASS = 'border-transparent'; + +const PROVIDER_LABELS = { stremio: 'Stremio', trakt: 'Trakt', simkl: 'Simkl' }; + +let switchSectionFn = null; +const connectedState = { stremio: false, trakt: false, simkl: false }; + +export function initializeAccountsUI({ switchSection } = {}) { + switchSectionFn = switchSection || null; + + document.querySelectorAll('.source-btn').forEach(btn => { + btn.addEventListener('click', () => onSourceButtonClick(btn.dataset.sourceBtn)); + }); + + document.querySelectorAll('.account-link').forEach(link => { + link.addEventListener('click', () => goToAccounts()); + }); + + syncAccountsNextButton(); +} + +export function setStremioConnected(connected) { + connectedState.stremio = connected; + setProviderDot('stremio', connected); + setProviderView('stremio', connected); + + if (!connected) { + // Cascade: optional providers reset visually when Stremio drops. + // Tokens in window._watchlyOAuth and inline status text are managed + // by callers (resetApp / OAuth handlers), not here. 
+ setProviderConnected('trakt', false); + setProviderConnected('simkl', false); + setWatchHistorySource('stremio'); + } + + syncAccountsNextButton(); +} + +export function setProviderConnected(provider, connected) { + if (provider === 'stremio') { + setStremioConnected(connected); + return; + } + if (provider !== 'trakt' && provider !== 'simkl') return; + + connectedState[provider] = connected; + setProviderDot(provider, connected); + setProviderView(provider, connected); + + if (!connected && currentSource() === provider) { + setWatchHistorySource('stremio'); + } +} + +export function setWatchHistorySource(value) { + const hidden = document.getElementById('watchHistorySource'); + if (hidden) hidden.value = value; + document.querySelectorAll('.source-btn').forEach(btn => { + applyActive(btn, btn.dataset.sourceBtn === value); + }); +} + +function onSourceButtonClick(provider) { + if (provider !== 'stremio' && !connectedState[provider]) { + showToast(`Connect ${PROVIDER_LABELS[provider]} in Accounts to use it as your watch history source.`, 'info', 4000); + goToAccounts(provider); + return; + } + setWatchHistorySource(provider); +} + +function goToAccounts(scrollTo) { + if (typeof switchSectionFn === 'function') { + switchSectionFn('login'); + } + if (scrollTo) { + // Defer until the section is visible after switchSection completes. 
+ requestAnimationFrame(() => { + const target = document.getElementById(`provider-${scrollTo}`); + if (target) target.scrollIntoView({ behavior: 'smooth', block: 'start' }); + }); + } +} + +function syncAccountsNextButton() { + const btn = document.getElementById('accountsNextBtn'); + if (!btn) return; + btn.disabled = !connectedState.stremio; +} + +function setProviderView(provider, connected) { + const disconnected = document.querySelector(`[data-provider-view="disconnected"][data-provider-for="${provider}"]`); + const connectedEl = document.querySelector(`[data-provider-view="connected"][data-provider-for="${provider}"]`); + if (disconnected) disconnected.classList.toggle('hidden', connected); + if (connectedEl) connectedEl.classList.toggle('hidden', !connected); +} + +function currentSource() { + const hidden = document.getElementById('watchHistorySource'); + return hidden ? hidden.value : ''; +} + +function applyActive(btn, isActive) { + btn.classList.remove(...ACTIVE_CLASSES, ...INACTIVE_CLASSES, ACTIVE_BORDER_CLASS, INACTIVE_BORDER_CLASS); + if (isActive) { + btn.classList.add(...ACTIVE_CLASSES, ACTIVE_BORDER_CLASS); + } else { + btn.classList.add(...INACTIVE_CLASSES, INACTIVE_BORDER_CLASS); + } +} + +function setProviderDot(provider, connected) { + const dot = document.querySelector(`[data-account-dot="${provider}"]`); + if (dot) { + dot.classList.toggle('bg-green-400', connected); + dot.classList.toggle('bg-slate-500', !connected); + } + + const pip = document.querySelector(`[data-source-pip="${provider}"]`); + if (pip) { + pip.classList.toggle('bg-green-400', connected); + pip.classList.toggle('bg-slate-600', !connected); + } +} diff --git a/app/static/js/modules/auth-storage.js b/app/static/js/modules/auth-storage.js new file mode 100644 index 0000000..14a8ee9 --- /dev/null +++ b/app/static/js/modules/auth-storage.js @@ -0,0 +1,45 @@ +const STORAGE_KEY = 'watchly_auth'; +const EXPIRY_DAYS = 30; + +export function saveAuthToStorage(authData) { + try { + 
const expiryDate = new Date(); + expiryDate.setDate(expiryDate.getDate() + EXPIRY_DAYS); + const data = { + ...authData, + expiresAt: expiryDate.getTime() + }; + localStorage.setItem(STORAGE_KEY, JSON.stringify(data)); + } catch (e) { + console.warn('Failed to save auth to localStorage:', e); + } +} + +export function getAuthFromStorage() { + try { + const stored = localStorage.getItem(STORAGE_KEY); + if (!stored) return null; + + const data = JSON.parse(stored); + const now = Date.now(); + + if (data.expiresAt && data.expiresAt < now) { + clearAuthFromStorage(); + return null; + } + + return data; + } catch (e) { + console.warn('Failed to read auth from localStorage:', e); + clearAuthFromStorage(); + return null; + } +} + +export function clearAuthFromStorage() { + try { + localStorage.removeItem(STORAGE_KEY); + } catch (e) { + console.warn('Failed to clear auth from localStorage:', e); + } +} diff --git a/app/static/js/modules/auth-ui.js b/app/static/js/modules/auth-ui.js new file mode 100644 index 0000000..09fa411 --- /dev/null +++ b/app/static/js/modules/auth-ui.js @@ -0,0 +1,86 @@ +function getInitialsFromEmail(email) { + if (!email) return '?'; + + const username = email.split('@')[0]; + const parts = username.split(/[._-]/); + + if (parts.length >= 2) { + return (parts[0][0] + parts[1][0]).toUpperCase(); + } + return username.substring(0, 2).toUpperCase(); +} + +export function updateInstallMode(existingUser) { + const installHeader = document.querySelector('#sect-install h2'); + const installDesc = document.querySelector('#sect-install p'); + const btnText = document.querySelector('#submitBtn .btn-text'); + + if (existingUser) { + if (installHeader) installHeader.textContent = 'Update Settings'; + if (installDesc) installDesc.textContent = 'Update your preferences and re-install.'; + if (btnText) btnText.textContent = 'Update & Re-Install'; + return; + } + + if (installHeader) installHeader.textContent = 'Save & Install'; + if (installDesc) 
installDesc.textContent = 'Save your settings and install the addon.'; + if (btnText) btnText.textContent = 'Save & Install'; +} + +export function showUserProfile(email) { + const userProfileWrapper = document.getElementById('user-profile-dropdown-wrapper'); + const userEmail = document.getElementById('user-email'); + const userAvatar = document.getElementById('user-avatar'); + const loginStatusEmail = document.getElementById('loginStatusEmail'); + const loginStatusAvatar = document.getElementById('loginStatusAvatar'); + + const initials = getInitialsFromEmail(email); + + if (userProfileWrapper && userEmail && userAvatar) { + userEmail.textContent = email; + userAvatar.textContent = initials; + userProfileWrapper.classList.remove('hidden'); + } + + if (loginStatusEmail) loginStatusEmail.textContent = email; + if (loginStatusAvatar) loginStatusAvatar.textContent = initials; +} + +export function hideUserProfile() { + const userProfileWrapper = document.getElementById('user-profile-dropdown-wrapper'); + const dropdown = document.getElementById('user-profile-dropdown'); + + if (userProfileWrapper) { + userProfileWrapper.classList.add('hidden'); + } + + if (dropdown) { + dropdown.classList.add('hidden'); + const chevron = document.getElementById('user-profile-chevron'); + if (chevron) { + chevron.style.transform = 'rotate(0deg)'; + } + } +} + +export function renderLoggedInControls({ authKey }) { + const authKeyInput = document.getElementById('authKey'); + if (authKeyInput) authKeyInput.value = authKey || ''; +} + +export function renderLoggedOutControls({ emailInput, passwordInput }) { + const authKeyInput = document.getElementById('authKey'); + if (authKeyInput) authKeyInput.value = ''; + + if (emailInput) emailInput.value = ''; + if (passwordInput) passwordInput.value = ''; + + const toggleBtn = document.querySelector('.toggle-btn[data-target="passwordInput"]'); + const pwd = document.getElementById('passwordInput'); + if (toggleBtn && pwd) { + pwd.type = 
'password'; + toggleBtn.setAttribute('title', 'Show'); + toggleBtn.setAttribute('aria-label', 'Show password'); + toggleBtn.innerHTML = '<svg class="w-4 h-4" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M1 12s4-7 11-7 11 7 11 7-4 7-11 7-11-7-11-7z"/><circle cx="12" cy="12" r="3"/></svg>'; + } +} diff --git a/app/static/js/modules/auth.js b/app/static/js/modules/auth.js index d91fd31..241d1d5 100644 --- a/app/static/js/modules/auth.js +++ b/app/static/js/modules/auth.js @@ -1,7 +1,19 @@ // Authentication Logic import { showToast } from './ui.js'; -import { switchSection, unlockNavigation, lockNavigationForLoggedOut } from './navigation.js'; +import { clearAuthFromStorage, getAuthFromStorage, saveAuthToStorage } from './auth-storage.js'; +import { + hideUserProfile, + renderLoggedInControls, + renderLoggedOutControls, + showUserProfile, + updateInstallMode +} from './auth-ui.js'; +import { + setProviderConnected, + setStremioConnected, + setWatchHistorySource, +} from './accounts.js'; // DOM Elements - will be initialized let stremioLoginBtn = null; @@ -10,69 +22,26 @@ let emailInput = null; let passwordInput = null; let emailPwdContinueBtn = null; let languageSelect = null; -let getCatalogs = null; +let appState = null; let renderCatalogList = null; let resetApp = null; +let switchSection = null; +let unlockNavigation = null; +let updateYearSlider = null; -// LocalStorage keys -const STORAGE_KEY = 'watchly_auth'; -const EXPIRY_DAYS = 30; - -// LocalStorage helper functions -function saveAuthToStorage(authData) { - try { - const expiryDate = new Date(); - expiryDate.setDate(expiryDate.getDate() + EXPIRY_DAYS); - const data = { - ...authData, - expiresAt: expiryDate.getTime() - }; - localStorage.setItem(STORAGE_KEY, JSON.stringify(data)); - } catch (e) { - console.warn('Failed to save auth to localStorage:', e); - } -} - -function getAuthFromStorage() { - try { - const stored = 
localStorage.getItem(STORAGE_KEY); - if (!stored) return null; - - const data = JSON.parse(stored); - const now = Date.now(); - - // Check if expired - if (data.expiresAt && data.expiresAt < now) { - clearAuthFromStorage(); - return null; - } - - return data; - } catch (e) { - console.warn('Failed to read auth from localStorage:', e); - clearAuthFromStorage(); - return null; - } -} - -function clearAuthFromStorage() { - try { - localStorage.removeItem(STORAGE_KEY); - } catch (e) { - console.warn('Failed to clear auth from localStorage:', e); - } -} - -export function initializeAuth(domElements, catalogState) { +export function initializeAuth(domElements, state, actions) { stremioLoginBtn = domElements.stremioLoginBtn; stremioLoginText = domElements.stremioLoginText; emailInput = domElements.emailInput; passwordInput = domElements.passwordInput; emailPwdContinueBtn = domElements.emailPwdContinueBtn; languageSelect = domElements.languageSelect; - getCatalogs = catalogState.getCatalogs; - renderCatalogList = catalogState.renderCatalogList; - resetApp = catalogState.resetApp; + appState = state; + renderCatalogList = actions.renderCatalogList; + resetApp = actions.resetApp; + switchSection = actions.switchSection; + unlockNavigation = actions.unlockNavigation; + updateYearSlider = actions.updateYearSlider; // Initialize logout buttons initializeLoginStatusLogoutButton(); @@ -205,7 +174,7 @@ async function initializeStremioLogin() { const authKey = urlParams.get('key') || urlParams.get('authKey'); if (authKey) { - // Logged In -> Unlock and move to config + // Logged In -> Unlock; stay on Accounts so the user can connect optional providers setStremioLoggedInState(authKey); try { @@ -213,7 +182,7 @@ async function initializeStremioLogin() { // Save auth key to localStorage for persistent login saveAuthToStorage({ authKey }); unlockNavigation(); - switchSection('config'); + switchSection('login'); } catch (error) { showToast(error.message, "error"); 
clearAuthFromStorage(); @@ -302,7 +271,7 @@ async function fetchStremioIdentity(authKey) { if (s.popularity && popularitySelect) popularitySelect.value = s.popularity; if (s.year_min && yearMinInput) yearMinInput.value = s.year_min; if (s.year_max && yearMaxInput) yearMaxInput.value = s.year_max; - if (window.updateYearSlider) window.updateYearSlider(); + if (updateYearSlider) updateYearSlider(); const sortingOrderSelect = document.getElementById('sortingOrderSelect'); if (s.sorting_order && sortingOrderSelect) sortingOrderSelect.value = s.sorting_order; @@ -335,6 +304,9 @@ async function fetchStremioIdentity(authKey) { const geminiApiKeyInput = document.getElementById('geminiApiKey'); if (s.gemini_api_key && geminiApiKeyInput) geminiApiKeyInput.value = s.gemini_api_key; + // Watch History Source + OAuth tokens + restoreWatchHistoryState(s); + // Genres (Checked = Excluded) document.querySelectorAll('input[name="movie-genre"]').forEach(cb => cb.checked = false); document.querySelectorAll('input[name="series-genre"]').forEach(cb => cb.checked = false); @@ -350,7 +322,7 @@ async function fetchStremioIdentity(authKey) { // Catalogs if (s.catalogs && Array.isArray(s.catalogs)) { - const catalogs = getCatalogs ? getCatalogs() : []; + const catalogs = appState ? appState.catalogs : []; s.catalogs.forEach(remote => { const local = catalogs.find(c => c.id === remote.id); if (local) { @@ -367,24 +339,11 @@ async function fetchStremioIdentity(authKey) { } // Update UI for "Update Mode" - const installHeader = document.querySelector('#sect-install h2'); - const installDesc = document.querySelector('#sect-install p'); - if (installHeader) installHeader.textContent = "Update Settings"; - if (installDesc) installDesc.textContent = "Update your preferences and re-install."; - - const btnText = document.querySelector('#submitBtn .btn-text'); - if (btnText) btnText.textContent = "Update & Re-Install"; + updateInstallMode(true); } else { // New Account showToast(`Welcome! 
Setting up new account for ${userDisplay}`, "success", 5000); - - const installHeader = document.querySelector('#sect-install h2'); - const installDesc = document.querySelector('#sect-install p'); - if (installHeader) installHeader.textContent = "Save & Install"; - if (installDesc) installDesc.textContent = "Save your settings and install the addon."; - - const btnText = document.querySelector('#submitBtn .btn-text'); - if (btnText) btnText.textContent = "Save & Install"; + updateInstallMode(false); } } @@ -416,9 +375,8 @@ function initializeEmailPasswordLogin() { saveAuthToStorage({ email, password: pwd }); // Mark as logged-in (disables inputs and flips button to Logout) setStremioLoggedInState(''); - // Proceed to config + // Stay on Accounts so the user can connect optional providers unlockNavigation(); - switchSection('config'); } catch (e) { showEmailPwdError(e.message || 'Login failed'); clearAuthFromStorage(); @@ -461,36 +419,21 @@ function isValidEmail(value) { } export function setStremioLoggedInState(authKey) { - if (!stremioLoginBtn) return; - stremioLoginText.textContent = 'Logout'; - stremioLoginBtn.setAttribute('data-action', 'logout'); - stremioLoginBtn.classList.remove('bg-stremio', 'hover:bg-stremio-hover', 'hover:bg-white', 'hover:text-black', 'hover:border-white/10', 'border-stremio-border'); - stremioLoginBtn.classList.add('bg-red-600', 'hover:bg-red-700', 'border-red-700', 'shadow-red-900/20', 'text-white'); - - // Pre-fill hidden AuthKey for submission - const authKeyInput = document.getElementById('authKey'); - if (authKeyInput) authKeyInput.value = authKey; + if (appState) { + appState.auth.loggedIn = true; + appState.auth.authKey = authKey || ''; + } - // Hide email/password login block and its disclaimer; keep only Logout button visible - try { - const emailPwdSection = document.getElementById('emailPwdSection'); - const disclaimer = document.getElementById('emailPwdDisclaimer'); - const divider = 
document.getElementById('emailPwdDivider'); - if (emailPwdSection) emailPwdSection.classList.add('hidden'); - if (disclaimer) disclaimer.classList.add('hidden'); - if (divider) divider.classList.add('hidden'); - } catch (e) { /* noop */ } + renderLoggedInControls({ stremioLoginBtn, stremioLoginText, authKey }); + setStremioConnected(true); } export function setStremioLoggedOutState() { - if (!stremioLoginBtn) return; - stremioLoginText.textContent = 'Login with Stremio'; - stremioLoginBtn.removeAttribute('data-action'); - stremioLoginBtn.classList.add('bg-stremio', 'hover:bg-white', 'hover:text-black', 'hover:border-white/10', 'border-stremio-border', 'text-white'); - stremioLoginBtn.classList.remove('bg-red-600', 'hover:bg-red-700', 'border-red-700', 'shadow-red-900/20'); - - const authKeyInput = document.getElementById('authKey'); - if (authKeyInput) authKeyInput.value = ''; + if (appState) { + appState.auth.loggedIn = false; + appState.auth.authKey = ''; + appState.auth.userDisplay = null; + } // Clear stored auth credentials clearAuthFromStorage(); @@ -498,105 +441,83 @@ export function setStremioLoggedOutState() { // Hide user profile hideUserProfile(); - // Restore email/password login block visibility and clear inputs - try { - const emailPwdSection = document.getElementById('emailPwdSection'); - const disclaimer = document.getElementById('emailPwdDisclaimer'); - const divider = document.getElementById('emailPwdDivider'); - if (emailPwdSection) emailPwdSection.classList.remove('hidden'); - if (disclaimer) disclaimer.classList.remove('hidden'); - if (divider) divider.classList.remove('hidden'); - if (emailInput) { emailInput.value = ''; } - if (passwordInput) { passwordInput.value = ''; } - // Reset password toggle button state to hidden - const toggleBtn = document.querySelector('.toggle-btn[data-target="passwordInput"]'); - const pwd = document.getElementById('passwordInput'); - if (toggleBtn && pwd) { - pwd.type = 'password'; - 
toggleBtn.setAttribute('title', 'Show'); - toggleBtn.setAttribute('aria-label', 'Show password'); - toggleBtn.innerHTML = '<svg class="w-4 h-4" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M1 12s4-7 11-7 11 7 11 7-4 7-11 7-11-7-11-7z"/><circle cx="12" cy="12" r="3"/></svg>'; - } - } catch (e) { /* noop */ } + renderLoggedOutControls({ stremioLoginBtn, stremioLoginText, emailInput, passwordInput }); + setStremioConnected(false); } -// User Profile Functions -function showUserProfile(email) { - const userProfileWrapper = document.getElementById('user-profile-dropdown-wrapper'); - const userEmail = document.getElementById('user-email'); - const userAvatar = document.getElementById('user-avatar'); - - // Login status section elements - const loginStatusSection = document.getElementById('loginStatusSection'); - const loginStatusEmail = document.getElementById('loginStatusEmail'); - const loginStatusAvatar = document.getElementById('loginStatusAvatar'); - - if (!userProfileWrapper || !userEmail || !userAvatar) return; +// Restore Watch History Source and OAuth connected state from saved settings +function restoreWatchHistoryState(settings) { + window._watchlyOAuth = window._watchlyOAuth || {}; - // Set email - userEmail.textContent = email; - - // Generate avatar initials from email - const initials = getInitialsFromEmail(email); - userAvatar.textContent = initials; - - // Show the profile dropdown wrapper - userProfileWrapper.classList.remove('hidden'); - - // Show login status section and update it - if (loginStatusSection && loginStatusEmail && loginStatusAvatar) { - loginStatusEmail.textContent = email; - loginStatusAvatar.textContent = initials; - loginStatusSection.classList.remove('hidden'); - } - - // Hide the login form when logged in - const loginFormCard = document.getElementById('loginFormCard'); - if (loginFormCard) loginFormCard.classList.add('hidden'); -} - -function 
hideUserProfile() { - const userProfileWrapper = document.getElementById('user-profile-dropdown-wrapper'); - const dropdown = document.getElementById('user-profile-dropdown'); - const loginStatusSection = document.getElementById('loginStatusSection'); - - if (userProfileWrapper) { - userProfileWrapper.classList.add('hidden'); + if (settings.trakt_access_token) { + window._watchlyOAuth.trakt = { + access_token: settings.trakt_access_token, + refresh_token: settings.trakt_refresh_token || '', + expires_at: settings.trakt_token_expires_at || 0, + }; + const traktStatus = document.getElementById('traktStatus'); + if (traktStatus) { + traktStatus.textContent = 'Connected'; + traktStatus.classList.remove('text-slate-500'); + traktStatus.classList.add('text-green-400'); + } + const traktLogoutBtn = document.getElementById('traktLogoutBtn'); + if (traktLogoutBtn) traktLogoutBtn.classList.remove('hidden'); + setProviderConnected('trakt', true); + validateAndShowTraktUser(settings.trakt_access_token); } - // Close dropdown if open - if (dropdown) { - dropdown.classList.add('hidden'); - const chevron = document.getElementById('user-profile-chevron'); - if (chevron) { - chevron.style.transform = 'rotate(0deg)'; + if (settings.simkl_access_token) { + window._watchlyOAuth.simkl = { + access_token: settings.simkl_access_token, + }; + const simklSyncStatus = document.getElementById('simklSyncStatus'); + if (simklSyncStatus) { + simklSyncStatus.textContent = 'Connected'; + simklSyncStatus.classList.remove('text-slate-500'); + simklSyncStatus.classList.add('text-green-400'); } + const simklSyncLogoutBtn = document.getElementById('simklSyncLogoutBtn'); + if (simklSyncLogoutBtn) simklSyncLogoutBtn.classList.remove('hidden'); + setProviderConnected('simkl', true); + validateAndShowSimklUser(settings.simkl_access_token); } - // Hide login status section - if (loginStatusSection) { - loginStatusSection.classList.add('hidden'); + if (settings.watch_history_source) { + 
setWatchHistorySource(settings.watch_history_source); } - - // Show the login form when logged out - const loginFormCard = document.getElementById('loginFormCard'); - if (loginFormCard) loginFormCard.classList.remove('hidden'); } -function getInitialsFromEmail(email) { - if (!email) return '?'; - - // If it's an email, get the part before @ - const username = email.split('@')[0]; - - // Split by common separators (., _, -) - const parts = username.split(/[._-]/); +async function validateAndShowTraktUser(accessToken) { + try { + const res = await fetch('/trakt/validation', { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ access_token: accessToken }), + }); + const data = await res.json(); + const traktStatus = document.getElementById('traktStatus'); + if (data.valid && traktStatus) { + traktStatus.textContent = data.message; // "Connected as username" + } + } catch (e) { + // Silently ignore — status already shows "Connected" + } +} - if (parts.length >= 2) { - // Take first letter of first two parts - return (parts[0][0] + parts[1][0]).toUpperCase(); - } else { - // Take first two letters of username - return username.substring(0, 2).toUpperCase(); +async function validateAndShowSimklUser(accessToken) { + try { + const res = await fetch('/simkl-sync/validation', { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify({ access_token: accessToken }), + }); + const data = await res.json(); + const simklSyncStatus = document.getElementById('simklSyncStatus'); + if (data.valid && simklSyncStatus) { + simklSyncStatus.textContent = data.message; // "Connected as username" + } + } catch (e) { + // Silently ignore — status already shows "Connected" } } diff --git a/app/static/js/modules/catalog.js b/app/static/js/modules/catalog.js index d35790c..311ef73 100644 --- a/app/static/js/modules/catalog.js +++ b/app/static/js/modules/catalog.js @@ -2,33 +2,19 @@ import { escapeHtml } from 
'./ui.js'; -let catalogs = []; let catalogList = null; +let appState = null; -export function initializeCatalogList(domElements, catalogState) { +export function initializeCatalogList(domElements, state) { catalogList = domElements.catalogList; - // Use the catalogs array from catalogState (shared reference) - if (catalogState && catalogState.catalogs) { - // Replace the array contents to maintain reference - catalogs.length = 0; - catalogs.push(...catalogState.catalogs); - } + appState = state; renderCatalogList(); } -export function setCatalogs(newCatalogs) { - catalogs.length = 0; - catalogs.push(...newCatalogs); -} - -export function getCatalogs() { - return catalogs; -} - export function renderCatalogList() { - if (!catalogList) return; + if (!catalogList || !appState) return; catalogList.innerHTML = ''; - catalogs.forEach((cat, index) => { + appState.catalogs.forEach((cat, index) => { const item = createCatalogItem(cat, index); catalogList.appendChild(item); }); @@ -36,13 +22,13 @@ export function renderCatalogList() { function moveCatalogUp(index) { if (index === 0) return; - [catalogs[index], catalogs[index - 1]] = [catalogs[index - 1], catalogs[index]]; + [appState.catalogs[index], appState.catalogs[index - 1]] = [appState.catalogs[index - 1], appState.catalogs[index]]; renderCatalogList(); } function moveCatalogDown(index) { - if (index === catalogs.length - 1) return; - [catalogs[index], catalogs[index + 1]] = [catalogs[index + 1], catalogs[index]]; + if (index === appState.catalogs.length - 1) return; + [appState.catalogs[index], appState.catalogs[index + 1]] = [appState.catalogs[index + 1], appState.catalogs[index]]; renderCatalogList(); } @@ -54,7 +40,10 @@ function createCatalogItem(cat, index) { item.setAttribute('data-id', cat.id); item.setAttribute('data-index', index); - const isRenamable = cat.id !== 'watchly.theme'; + // watchly.theme builds names from genres/keywords at runtime; watchly.item + // builds them from the seed bucket ("Because you 
loved/watched"). Both + // would discard a user-supplied name, so the rename button is hidden. + const isRenamable = cat.id !== 'watchly.theme' && cat.id !== 'watchly.item'; // Determine active mode for toggle buttons const enabledMovie = cat.enabledMovie !== false; @@ -73,7 +62,7 @@ function createCatalogItem(cat, index) { <svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="m18 15-6-6-6 6"/></svg> </button> <div class="h-9 flex items-center"> - <button type="button" class="action-btn move-down p-2 text-blue-400 bg-blue-500/20 hover:bg-blue-500/30 border border-blue-500/40 hover:border-blue-400/60 rounded-lg transition-all disabled:opacity-30 disabled:hover:bg-blue-500/20 disabled:cursor-not-allowed shadow-sm hover:shadow-md hover:shadow-blue-500/20" title="Move down" ${index === catalogs.length - 1 ? 'disabled' : ''}> + <button type="button" class="action-btn move-down p-2 text-blue-400 bg-blue-500/20 hover:bg-blue-500/30 border border-blue-500/40 hover:border-blue-400/60 rounded-lg transition-all disabled:opacity-30 disabled:hover:bg-blue-500/20 disabled:cursor-not-allowed shadow-sm hover:shadow-md hover:shadow-blue-500/20" title="Move down" ${index === appState.catalogs.length - 1 ? 
'disabled' : ''}> <svg xmlns="http://www.w3.org/2000/svg" width="16" height="16" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="m6 9 6 6 6-6"/></svg> </button> </div> diff --git a/app/static/js/modules/field-helpers.js b/app/static/js/modules/field-helpers.js new file mode 100644 index 0000000..da901e1 --- /dev/null +++ b/app/static/js/modules/field-helpers.js @@ -0,0 +1,106 @@ +const LOADING_ICON = '<svg class="w-5 h-5 animate-spin" fill="none" stroke="currentColor" viewBox="0 0 24 24"><circle class="opacity-25" cx="12" cy="12" r="10" stroke="currentColor" stroke-width="4"></circle><path class="opacity-75" fill="currentColor" d="M4 12a8 8 0 018-8V0C5.373 0 0 5.373 0 12h4zm2 5.291A7.962 7.962 0 014 12H0c0 3.042 1.135 5.824 3 7.938l3-2.647z"></path></svg>'; + +export function setValidationMessage(validationMessage, message, type) { + if (!validationMessage) return; + validationMessage.textContent = message; + validationMessage.className = `mt-2 text-xs ${type === 'success' ? 'text-green-400' : 'text-red-400'}`; + validationMessage.classList.remove('hidden'); +} + +export function clearValidationMessage(validationMessage) { + if (validationMessage) { + validationMessage.classList.add('hidden'); + } +} + +export function initializeEyeToggle({ input, toggleBtn, eyeIcon, eyeOffIcon }) { + if (!input || !toggleBtn || !eyeIcon || !eyeOffIcon) return; + + toggleBtn.addEventListener('click', () => { + const isPassword = input.type === 'password'; + input.type = isPassword ? 
'text' : 'password'; + eyeIcon.classList.toggle('hidden', !isPassword); + eyeOffIcon.classList.toggle('hidden', isPassword); + }); +} + +export function initializePasswordToggleButton(selector = '.toggle-btn') { + document.querySelectorAll(selector).forEach(btn => { + btn.addEventListener('click', () => { + const targetId = btn.getAttribute('data-target'); + const input = document.getElementById(targetId); + if (!input) return; + const isHidden = input.type === 'password'; + input.type = isHidden ? 'text' : 'password'; + if (isHidden) { + btn.setAttribute('title', 'Hide'); + btn.setAttribute('aria-label', 'Hide password'); + btn.innerHTML = '<svg class="w-4 h-4" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M17.94 17.94A10.94 10.94 0 0 1 12 20c-7 0-11-8-11-8a21.77 21.77 0 0 1 5.06-6.17M9.9 4.24A10.94 10.94 0 0 1 12 4c7 0 11 8 11 8a21.8 21.8 0 0 1-3.22 4.31"/><path d="M1 1l22 22"/><path d="M14.12 14.12A3 3 0 0 1 9.88 9.88"/></svg>'; + } else { + btn.setAttribute('title', 'Show'); + btn.setAttribute('aria-label', 'Show password'); + btn.innerHTML = '<svg class="w-4 h-4" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M1 12s4-7 11-7 11 7 11 7-4 7-11 7-11-7-11-7z"/><circle cx="12" cy="12" r="3"/></svg>'; + } + }); + }); +} + +export function initializeValidatedSecretField({ + input, + validateBtn, + validationMessage, + toggleBtn, + eyeIcon, + eyeOffIcon, + emptyMessage, + successMessage, + request, + getErrorMessage, + onValid, + onInvalid, + onErrorMessage = 'Validation failed. Please try again.' 
+}) { + if (!input || !validateBtn || !validationMessage) { + return async () => false; + } + + initializeEyeToggle({ input, toggleBtn, eyeIcon, eyeOffIcon }); + + async function validate() { + const value = input.value.trim(); + if (!value) { + setValidationMessage(validationMessage, emptyMessage, 'error'); + return false; + } + + validateBtn.disabled = true; + validateBtn.classList.add('opacity-50', 'cursor-not-allowed'); + const originalHTML = validateBtn.innerHTML; + validateBtn.innerHTML = LOADING_ICON; + + try { + const data = await request(value); + if (data.valid) { + setValidationMessage(validationMessage, successMessage, 'success'); + if (onValid) onValid(data); + return true; + } + + setValidationMessage(validationMessage, getErrorMessage ? getErrorMessage(data) : 'Invalid API key', 'error'); + if (onInvalid) onInvalid(data); + return false; + } catch (error) { + setValidationMessage(validationMessage, onErrorMessage, 'error'); + return false; + } finally { + validateBtn.disabled = false; + validateBtn.classList.remove('opacity-50', 'cursor-not-allowed'); + validateBtn.innerHTML = originalHTML; + } + } + + validateBtn.addEventListener('click', validate); + input.addEventListener('input', () => clearValidationMessage(validationMessage)); + + return validate; +} diff --git a/app/static/js/modules/form-success.js b/app/static/js/modules/form-success.js new file mode 100644 index 0000000..975cde4 --- /dev/null +++ b/app/static/js/modules/form-success.js @@ -0,0 +1,98 @@ +import { showConfirm, showToast } from './ui.js'; +import { switchSection } from './navigation.js'; + +export function initializeSuccessActions({ emailInput, passwordInput, resetApp, setLoading, showError }) { + const copyBtn = document.getElementById('copyBtn'); + if (copyBtn) { + copyBtn.addEventListener('click', async (e) => { + e.preventDefault(); + e.stopPropagation(); + const urlText = document.getElementById('addonUrl').textContent; + try { + await 
navigator.clipboard.writeText(urlText); + const originalText = copyBtn.textContent; + copyBtn.textContent = 'Copied!'; + setTimeout(() => { copyBtn.textContent = originalText; }, 2000); + } catch (err) { /* noop */ } + }); + } + + const installDesktopBtn = document.getElementById('installDesktopBtn'); + if (installDesktopBtn) { + installDesktopBtn.addEventListener('click', (e) => { + e.preventDefault(); + e.stopPropagation(); + const url = document.getElementById('addonUrl').textContent; + window.location.href = `stremio://${url.replace(/^https?:\/\//, '')}`; + }); + } + + const installWebBtn = document.getElementById('installWebBtn'); + if (installWebBtn) { + installWebBtn.addEventListener('click', (e) => { + e.preventDefault(); + e.stopPropagation(); + const url = document.getElementById('addonUrl').textContent; + window.open(`https://web.stremio.com/#/addons?addon=${encodeURIComponent(url)}`, '_blank'); + }); + } + + const deleteAccountBtn = document.getElementById('deleteAccountBtn'); + if (deleteAccountBtn) { + deleteAccountBtn.addEventListener('click', async () => { + const confirmed = await showConfirm( + 'Delete Account?', + 'Are you sure you want to delete your settings? This action is irreversible and all your data will be permanently removed.' 
+ ); + + if (!confirmed) return; + + const sAuthKey = (document.getElementById('authKey').value || '').trim(); + const email = emailInput?.value.trim(); + const password = passwordInput?.value; + + if (!sAuthKey && !(email && password)) { + showError('generalError', 'Provide Stremio auth key or email & password to delete your account.'); + switchSection('login'); + return; + } + + setLoading(true); + try { + const payload = { authKey: sAuthKey || undefined, email: email || undefined, password: password || undefined }; + const res = await fetch('/tokens/', { + method: 'DELETE', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify(payload) + }); + if (!res.ok) throw new Error((await res.json()).detail || 'Failed to delete'); + showToast('Account deleted successfully.', 'success'); + if (resetApp) resetApp(); + } catch (e) { + showError('generalError', e.message); + } finally { + setLoading(false); + } + }); + } +} + +export function showSuccessSection(url) { + const sections = { + welcome: document.getElementById('sect-welcome'), + login: document.getElementById('sect-login'), + config: document.getElementById('sect-config'), + catalogs: document.getElementById('sect-catalogs'), + install: document.getElementById('sect-install'), + success: document.getElementById('sect-success') + }; + + Object.values(sections).forEach(section => { + if (section) section.classList.add('hidden'); + }); + + if (sections.success) { + sections.success.classList.remove('hidden'); + document.getElementById('addonUrl').textContent = url; + } +} diff --git a/app/static/js/modules/form.js b/app/static/js/modules/form.js index 4c0c76d..09de03e 100644 --- a/app/static/js/modules/form.js +++ b/app/static/js/modules/form.js @@ -1,181 +1,192 @@ // Form Submission and UI Helpers -import { showToast, showConfirm, escapeHtml } from './ui.js'; +import { showToast } from './ui.js'; import { switchSection } from './navigation.js'; +import { + clearValidationMessage, + 
initializeEyeToggle, + initializePasswordToggleButton, + initializeValidatedSecretField, + setValidationMessage +} from './field-helpers.js'; +import { initializeSuccessActions, showSuccessSection } from './form-success.js'; +import { initializeYearSliderControl } from './year-slider.js'; import { MOVIE_GENRES, SERIES_GENRES } from '../constants.js'; +import { setProviderConnected } from './accounts.js'; + +const YEAR_RANGE_DEFAULTS = window.YEAR_RANGE_DEFAULTS || { min: 1970, max: new Date().getFullYear() }; +const LOADING_ICON = '<svg class="w-5 h-5 animate-spin" fill="none" stroke="currentColor" viewBox="0 0 24 24"><circle class="opacity-25" cx="12" cy="12" r="10" stroke="currentColor" stroke-width="4"></circle><path class="opacity-75" fill="currentColor" d="M4 12a8 8 0 018-8V0C5.373 0 0 5.373 0 12h4zm2 5.291A7.962 7.962 0 014 12H0c0 3.042 1.135 5.824 3 7.938l3-2.647z"></path></svg>'; // DOM Elements - will be initialized -let configForm = null; let submitBtn = null; let emailInput = null; let passwordInput = null; let languageSelect = null; let movieGenreList = null; let seriesGenreList = null; -let getCatalogs = null; +let appState = null; let resetApp = null; +let validatePosterRatingApiKey = null; +let updateYearSlider = () => {}; -export function initializeForm(domElements, catalogState) { - configForm = domElements.configForm; +export function initializeForm(domElements, state, actions) { submitBtn = domElements.submitBtn; emailInput = domElements.emailInput; passwordInput = domElements.passwordInput; languageSelect = domElements.languageSelect; movieGenreList = domElements.movieGenreList; seriesGenreList = domElements.seriesGenreList; - getCatalogs = catalogState.getCatalogs; - resetApp = catalogState.resetApp; + appState = state; + resetApp = actions.resetApp; initializeFormSubmission(); initializeGenreLists(); initializeLanguageSelect(); initializePasswordToggles(); - initializeSuccessActions(); - initializePosterRatingProvider(); + 
initializeSuccessHandlers(); + validatePosterRatingApiKey = initializePosterRatingProvider(); initializeTmdb(); initializeSimkl(); initializeGemini(); - initializeYearSlider(); + updateYearSlider = initializeYearSliderControl(); + initializeWatchHistorySource(); } -// Form Submission -async function initializeFormSubmission() { - if (!submitBtn) return; +async function postJson(url, payload) { + const response = await fetch(url, { + method: 'POST', + headers: { 'Content-Type': 'application/json' }, + body: JSON.stringify(payload) + }); - submitBtn.addEventListener("click", async (e) => { - e.preventDefault(); - clearErrors(); + return response.json(); +} - const sAuthKey = (document.getElementById("authKey").value || '').trim(); - const email = emailInput?.value.trim(); - const password = passwordInput?.value; - const language = languageSelect.value; - const popularity = document.getElementById("popularitySelect")?.value || "balanced"; - const yearMin = parseInt(document.getElementById("yearMin")?.value || "1980"); - const yearMax = parseInt(document.getElementById("yearMax")?.value || "2026"); - const sortingOrder = document.getElementById("sortingOrderSelect")?.value || "default"; - const posterRatingProvider = document.getElementById("posterRatingProvider")?.value || ""; - const posterRatingApiKey = document.getElementById("posterRatingApiKey")?.value.trim() || ""; - const excludedMovieGenres = Array.from(document.querySelectorAll('input[name="movie-genre"]:checked')).map(cb => cb.value); - const excludedSeriesGenres = Array.from(document.querySelectorAll('input[name="series-genre"]:checked')).map(cb => cb.value); - const tmdbApiKey = document.getElementById("tmdbApiKey")?.value.trim() || ""; - const simklApiKey = document.getElementById("simklApiKey")?.value.trim() || ""; - const geminiApiKey = document.getElementById("geminiApiKey")?.value.trim() || ""; - - const catalogsToSend = []; - const catalogs = getCatalogs ? 
getCatalogs() : []; - // Get enabled state from catalog objects (updated by visibility button) - catalogs.forEach(originalCatalog => { - const catalogId = originalCatalog.id; - const enabled = originalCatalog.enabled !== false; - - // Get enabled_movie and enabled_series from toggle buttons - const activeBtn = document.querySelector(`.catalog-type-btn[data-catalog-id="${catalogId}"].bg-white`); - let enabledMovie = true; - let enabledSeries = true; - - if (activeBtn) { - const mode = activeBtn.dataset.mode; - if (mode === 'movie') { - enabledMovie = true; - enabledSeries = false; - } else if (mode === 'series') { - enabledMovie = false; - enabledSeries = true; - } else { - // 'both' or default - enabledMovie = true; - enabledSeries = true; - } - } else { - // Fallback to catalog state - enabledMovie = originalCatalog.enabledMovie !== false; - enabledSeries = originalCatalog.enabledSeries !== false; - } +function getRequestPayload() { + const catalogs = appState ? appState.catalogs : []; + + return { + authKey: (document.getElementById('authKey')?.value || '').trim() || undefined, + email: emailInput?.value.trim() || undefined, + password: passwordInput?.value || undefined, + catalogs: catalogs.map(catalog => ({ + id: catalog.id, + name: catalog.name, + enabled: catalog.enabled !== false, + enabled_movie: catalog.enabledMovie !== false, + enabled_series: catalog.enabledSeries !== false, + display_at_home: catalog.display_at_home !== false, + shuffle: catalog.shuffle === true + })), + language: languageSelect?.value || 'english', + year_min: parseInt(document.getElementById('yearMin')?.value || String(YEAR_RANGE_DEFAULTS.min), 10), + year_max: parseInt(document.getElementById('yearMax')?.value || String(YEAR_RANGE_DEFAULTS.max), 10), + popularity: document.getElementById('popularitySelect')?.value || 'balanced', + sorting_order: document.getElementById('sortingOrderSelect')?.value || 'default', + poster_rating_provider: 
document.getElementById('posterRatingProvider')?.value || '', + poster_rating_api_key: document.getElementById('posterRatingApiKey')?.value.trim() || '', + tmdb_api_key: document.getElementById('tmdbApiKey')?.value.trim() || '', + simkl_api_key: document.getElementById('simklApiKey')?.value.trim() || '', + gemini_api_key: document.getElementById('geminiApiKey')?.value.trim() || '', + excluded_movie_genres: Array.from(document.querySelectorAll('input[name="movie-genre"]:checked')).map(cb => cb.value), + excluded_series_genres: Array.from(document.querySelectorAll('input[name="series-genre"]:checked')).map(cb => cb.value), + watch_history_source: document.getElementById('watchHistorySource')?.value || 'stremio', + }; +} - catalogsToSend.push({ - id: catalogId, - name: originalCatalog.name, - enabled: enabled, - enabled_movie: enabledMovie, - enabled_series: enabledSeries, - display_at_home: originalCatalog.display_at_home !== false, // Default to true if not set - shuffle: originalCatalog.shuffle === true, // Default to false if not set - }); - }); +function buildTokenPayload(formData) { + let posterRating; + if (formData.poster_rating_provider && formData.poster_rating_api_key) { + posterRating = { + provider: formData.poster_rating_provider, + api_key: formData.poster_rating_api_key + }; + } - // Validation - if (!sAuthKey && !(email && password)) { - showError("generalError", "Please login with Stremio or enter email & password."); - switchSection('login'); - return; + return { + authKey: formData.authKey, + email: formData.email, + password: formData.password, + catalogs: formData.catalogs, + language: formData.language, + year_min: formData.year_min, + year_max: formData.year_max, + popularity: formData.popularity, + sorting_order: formData.sorting_order, + poster_rating: posterRating || null, + tmdb_api_key: formData.tmdb_api_key || undefined, + simkl_api_key: formData.simkl_api_key, + gemini_api_key: formData.gemini_api_key, + excluded_movie_genres: 
formData.excluded_movie_genres, + excluded_series_genres: formData.excluded_series_genres, + watch_history_source: formData.watch_history_source, + trakt_access_token: window._watchlyOAuth?.trakt?.access_token || undefined, + trakt_refresh_token: window._watchlyOAuth?.trakt?.refresh_token || undefined, + trakt_token_expires_at: window._watchlyOAuth?.trakt?.expires_at || undefined, + simkl_access_token: window._watchlyOAuth?.simkl?.access_token || undefined, + }; +} + +function validateFormData(formData) { + if (!formData.authKey && !(formData.email && formData.password)) { + showError('generalError', 'Please login with Stremio or enter email & password.'); + switchSection('login'); + return false; + } + + if (!formData.tmdb_api_key) { + showError('generalError', 'TMDB API key is required.'); + const tmdbInput = document.getElementById('tmdbApiKey'); + if (tmdbInput) { + tmdbInput.focus(); + tmdbInput.scrollIntoView({ behavior: 'smooth', block: 'center' }); } + return false; + } - if (!tmdbApiKey) { - showError("generalError", "TMDB API key is required."); - const tmdbInput = document.getElementById("tmdbApiKey"); - if (tmdbInput) { - tmdbInput.focus(); - tmdbInput.scrollIntoView({ behavior: "smooth", block: "center" }); - } + return true; +} + +// Form Submission +function initializeFormSubmission() { + if (!submitBtn) return; + + submitBtn.addEventListener('click', async (e) => { + e.preventDefault(); + clearErrors(); + + const formData = getRequestPayload(); + if (!validateFormData(formData)) { return; } - // Validate poster rating API key if provided - if (posterRatingProvider && posterRatingApiKey) { - if (window.validatePosterRatingApiKey) { - const isValid = await window.validatePosterRatingApiKey(); - if (!isValid) { - return; - } + if (formData.poster_rating_provider && formData.poster_rating_api_key && validatePosterRatingApiKey) { + const isValid = await validatePosterRatingApiKey(); + if (!isValid) { + return; } } setLoading(true); try { - // Build 
poster_rating payload
-      let posterRating = null;
-      if (posterRatingProvider && posterRatingApiKey) {
-        posterRating = {
-          provider: posterRatingProvider,
-          api_key: posterRatingApiKey
-        };
-      }
-
-      const payload = {
-        authKey: sAuthKey || undefined,
-        email: email || undefined,
-        password: password || undefined,
-        catalogs: catalogsToSend,
-        language: language,
-        year_min: yearMin,
-        year_max: yearMax,
-        popularity: popularity,
-        sorting_order: sortingOrder,
-        poster_rating: posterRating,
-        tmdb_api_key: tmdbApiKey || undefined,
-        simkl_api_key: simklApiKey,
-        gemini_api_key: geminiApiKey,
-        excluded_movie_genres: excludedMovieGenres,
-        excluded_series_genres: excludedSeriesGenres
-      };
-
-      const response = await fetch("/tokens/", {
-        method: "POST",
-        headers: { "Content-Type": "application/json" },
+      const payload = buildTokenPayload(formData);
+      const response = await fetch('/tokens/', {
+        method: 'POST',
+        headers: { 'Content-Type': 'application/json' },
         body: JSON.stringify(payload)
       });
 
       if (!response.ok) {
         const errorData = await response.json();
-        throw new Error(errorData.detail || "Failed to generate manifest URL");
+        throw new Error(errorData.detail || 'Failed to generate manifest URL');
       }
-      showSuccess(data.manifestUrl);
+
+      const data = await response.json();
+      showSuccessSection(data.manifestUrl);
     } catch (error) {
-      console.error("Error:", error);
-      showError("generalError", error.message);
+      console.error('Error:', error);
+      showError('generalError', error.message);
     } finally {
       setLoading(false);
     }
@@ -190,6 +201,7 @@ function initializeGenreLists() {
 
 function renderGenreList(container, genres, namePrefix) {
   if (!container) return;
+
   container.innerHTML = genres.map(genre => `
     <label class="flex items-center gap-3 p-2 rounded-lg hover:bg-white/5 cursor-pointer transition group">
       <div class="relative flex items-center">
@@ -211,473 +223,200 @@ function initializeLanguageSelect() {
 
 // Poster Rating Provider
 function initializePosterRatingProvider() {
-  const providerSelect = 
document.getElementById("posterRatingProvider"); - const apiKeyContainer = document.getElementById("posterRatingApiKeyContainer"); - const apiKeyInput = document.getElementById("posterRatingApiKey"); - const helpContainer = document.getElementById("posterRatingHelp"); - const helpText = document.getElementById("posterRatingHelpText"); - const validateBtn = document.getElementById("posterRatingApiKeyValidate"); - const toggleBtn = document.getElementById("posterRatingApiKeyToggle"); - const eyeIcon = document.getElementById("posterRatingApiKeyEye"); - const eyeOffIcon = document.getElementById("posterRatingApiKeyEyeOff"); - const validationMessage = document.getElementById("posterRatingValidationMessage"); - - if (!providerSelect || !apiKeyContainer || !apiKeyInput || !helpContainer || !helpText) return; + const providerSelect = document.getElementById('posterRatingProvider'); + const apiKeyContainer = document.getElementById('posterRatingApiKeyContainer'); + const apiKeyInput = document.getElementById('posterRatingApiKey'); + const helpContainer = document.getElementById('posterRatingHelp'); + const helpText = document.getElementById('posterRatingHelpText'); + const validateBtn = document.getElementById('posterRatingApiKeyValidate'); + const toggleBtn = document.getElementById('posterRatingApiKeyToggle'); + const eyeIcon = document.getElementById('posterRatingApiKeyEye'); + const eyeOffIcon = document.getElementById('posterRatingApiKeyEyeOff'); + const validationMessage = document.getElementById('posterRatingValidationMessage'); + + if (!providerSelect || !apiKeyContainer || !apiKeyInput || !helpContainer || !helpText) { + return null; + } const providerInfo = { - "rpdb": { - name: "RPDB (RatingPosterDB)", - url: "https://ratingposterdb.com", - description: "Enable ratings on posters via RatingPosterDB" + rpdb: { + name: 'RPDB (RatingPosterDB)', + url: 'https://ratingposterdb.com', + description: 'Enable ratings on posters via RatingPosterDB' }, - "top_posters": { 
- name: "Top Posters", - url: "https://api.top-streaming.stream/", - description: "Enable ratings on posters via Top Posters" + top_posters: { + name: 'Top Posters', + url: 'https://api.top-streaming.stream/', + description: 'Enable ratings on posters via Top Posters' } }; let isValidated = false; - // Eye toggle functionality - if (toggleBtn && eyeIcon && eyeOffIcon) { - toggleBtn.addEventListener("click", () => { - const isPassword = apiKeyInput.type === "password"; - apiKeyInput.type = isPassword ? "text" : "password"; - eyeIcon.classList.toggle("hidden", !isPassword); - eyeOffIcon.classList.toggle("hidden", isPassword); - }); + initializeEyeToggle({ input: apiKeyInput, toggleBtn, eyeIcon, eyeOffIcon }); + + function resetValidation() { + isValidated = false; + clearValidationMessage(validationMessage); + } + + function updateUI() { + const selectedProvider = providerSelect.value; + const info = providerInfo[selectedProvider]; + + if (info) { + apiKeyContainer.style.display = 'block'; + helpContainer.style.display = 'block'; + helpText.innerHTML = `${info.description}. 
Get your API key from <a href="${info.url}" target="_blank" class="text-slate-300 hover:text-white underline">${info.name}</a>.`; + resetValidation(); + return; + } + + apiKeyContainer.style.display = 'none'; + helpContainer.style.display = 'none'; + apiKeyInput.value = ''; + resetValidation(); } - // Validation function async function validateApiKey() { const selectedProvider = providerSelect.value; const apiKey = apiKeyInput.value.trim(); if (!selectedProvider || !apiKey) { - showValidationMessage("Please select a provider and enter an API key", "error"); + setValidationMessage(validationMessage, 'Please select a provider and enter an API key', 'error'); return false; } - if (!validateBtn) return false; + if (!validateBtn) { + return false; + } - // Show loading state validateBtn.disabled = true; - validateBtn.classList.add("opacity-50", "cursor-not-allowed"); + validateBtn.classList.add('opacity-50', 'cursor-not-allowed'); const originalHTML = validateBtn.innerHTML; - validateBtn.innerHTML = '<svg class="w-5 h-5 animate-spin" fill="none" stroke="currentColor" viewBox="0 0 24 24"><circle class="opacity-25" cx="12" cy="12" r="10" stroke="currentColor" stroke-width="4"></circle><path class="opacity-75" fill="currentColor" d="M4 12a8 8 0 018-8V0C5.373 0 0 5.373 0 12h4zm2 5.291A7.962 7.962 0 014 12H0c0 3.042 1.135 5.824 3 7.938l3-2.647z"></path></svg>'; + validateBtn.innerHTML = LOADING_ICON; try { - const response = await fetch("/poster-rating/validate", { - method: "POST", - headers: { "Content-Type": "application/json" }, - body: JSON.stringify({ provider: selectedProvider, api_key: apiKey }) + const data = await postJson('/poster-rating/validate', { + provider: selectedProvider, + api_key: apiKey }); - const data = await response.json(); - if (data.valid) { - showValidationMessage("API key is valid ✓", "success"); + setValidationMessage(validationMessage, 'API key is valid ✓', 'success'); isValidated = true; return true; - } else { - 
showValidationMessage(data.message || "Invalid API key", "error"); - apiKeyInput.value = ""; // Clear invalid key - isValidated = false; - return false; } + + setValidationMessage(validationMessage, data.message || 'Invalid API key', 'error'); + apiKeyInput.value = ''; + isValidated = false; + return false; } catch (error) { - showValidationMessage("Validation failed. Please try again.", "error"); + setValidationMessage(validationMessage, 'Validation failed. Please try again.', 'error'); isValidated = false; return false; } finally { validateBtn.disabled = false; - validateBtn.classList.remove("opacity-50", "cursor-not-allowed"); + validateBtn.classList.remove('opacity-50', 'cursor-not-allowed'); validateBtn.innerHTML = originalHTML; } } - // Show validation message - function showValidationMessage(message, type) { - if (!validationMessage) return; - validationMessage.textContent = message; - validationMessage.className = `mt-2 text-xs ${type === "success" ? "text-green-400" : "text-red-400"}`; - validationMessage.classList.remove("hidden"); - } - - // Clear validation message - function clearValidationMessage() { - if (validationMessage) { - validationMessage.classList.add("hidden"); - } - } - - // Validate button click if (validateBtn) { - validateBtn.addEventListener("click", validateApiKey); + validateBtn.addEventListener('click', validateApiKey); } - // Clear validation when API key changes - apiKeyInput.addEventListener("input", () => { - isValidated = false; - clearValidationMessage(); - }); - - function updateUI() { - const selectedProvider = providerSelect.value; + apiKeyInput.addEventListener('input', resetValidation); + providerSelect.addEventListener('change', updateUI); + updateUI(); - if (selectedProvider && providerInfo[selectedProvider]) { - const info = providerInfo[selectedProvider]; - apiKeyContainer.style.display = "block"; - helpContainer.style.display = "block"; - helpText.innerHTML = `${info.description}. 
Get your API key from <a href="${info.url}" target="_blank" class="text-slate-300 hover:text-white underline">${info.name}</a>.`; - // Don't clear the API key when switching providers - just reset validation - isValidated = false; - clearValidationMessage(); - } else { - // Only clear when provider is set to "None" - apiKeyContainer.style.display = "none"; - helpContainer.style.display = "none"; - apiKeyInput.value = ""; - isValidated = false; - clearValidationMessage(); + return async () => { + if (isValidated) { + return true; } - } - - // Handle provider change - preserve API key value, just reset validation - providerSelect.addEventListener("change", () => { - isValidated = false; - clearValidationMessage(); - updateUI(); - }); - updateUI(); // Initialize on load - - // Export validate function for form submission - window.validatePosterRatingApiKey = validateApiKey; + return validateApiKey(); + }; } // TMDB API Key (Required) function initializeTmdb() { - const apiKeyInput = document.getElementById("tmdbApiKey"); - const validateBtn = document.getElementById("tmdbApiKeyValidate"); - const toggleBtn = document.getElementById("tmdbApiKeyToggle"); - const eyeIcon = document.getElementById("tmdbApiKeyEye"); - const eyeOffIcon = document.getElementById("tmdbApiKeyEyeOff"); - const validationMessage = document.getElementById("tmdbValidationMessage"); - - if (!apiKeyInput || !validationMessage) return; - - if (toggleBtn && eyeIcon && eyeOffIcon) { - toggleBtn.addEventListener("click", () => { - const isPassword = apiKeyInput.type === "password"; - apiKeyInput.type = isPassword ? "text" : "password"; - eyeIcon.classList.toggle("hidden", !isPassword); - eyeOffIcon.classList.toggle("hidden", isPassword); - }); - } - - function showTmdbValidationMessage(message, type) { - validationMessage.textContent = message; - validationMessage.className = `mt-2 text-xs ${type === "success" ? 
"text-green-400" : "text-red-400"}`; - validationMessage.classList.remove("hidden"); - } - - if (validateBtn) { - validateBtn.addEventListener("click", async () => { - const apiKey = apiKeyInput.value.trim(); - if (!apiKey) { - showTmdbValidationMessage("Please enter a TMDB API key", "error"); - return; - } - validateBtn.disabled = true; - validateBtn.classList.add("opacity-50", "cursor-not-allowed"); - const originalHTML = validateBtn.innerHTML; - validateBtn.innerHTML = '<svg class="w-5 h-5 animate-spin" fill="none" stroke="currentColor" viewBox="0 0 24 24"><circle class="opacity-25" cx="12" cy="12" r="10" stroke="currentColor" stroke-width="4"></circle><path class="opacity-75" fill="currentColor" d="M4 12a8 8 0 018-8V0C5.373 0 0 5.373 0 12h4zm2 5.291A7.962 7.962 0 014 12H0c0 3.042 1.135 5.824 3 7.938l3-2.647z"></path></svg>'; - try { - const response = await fetch("/tmdb/validation", { - method: "POST", - headers: { "Content-Type": "application/json" }, - body: JSON.stringify({ api_key: apiKey }) - }); - const data = await response.json(); - if (data.valid) { - showTmdbValidationMessage("TMDB API key is valid ✓", "success"); - } else { - showTmdbValidationMessage(data.message || "Invalid TMDB API key", "error"); - } - } catch (error) { - showTmdbValidationMessage("Validation failed. 
Please try again.", "error"); - } finally { - validateBtn.disabled = false; - validateBtn.classList.remove("opacity-50", "cursor-not-allowed"); - validateBtn.innerHTML = originalHTML; - } - }); - } - - apiKeyInput.addEventListener("input", () => validationMessage.classList.add("hidden")); + initializeValidatedSecretField({ + input: document.getElementById('tmdbApiKey'), + validateBtn: document.getElementById('tmdbApiKeyValidate'), + validationMessage: document.getElementById('tmdbValidationMessage'), + toggleBtn: document.getElementById('tmdbApiKeyToggle'), + eyeIcon: document.getElementById('tmdbApiKeyEye'), + eyeOffIcon: document.getElementById('tmdbApiKeyEyeOff'), + emptyMessage: 'Please enter a TMDB API key', + successMessage: 'TMDB API key is valid ✓', + request: (apiKey) => postJson('/tmdb/validation', { api_key: apiKey }), + getErrorMessage: (data) => data.message || 'Invalid TMDB API key' + }); } // Simkl Integration function initializeSimkl() { - const apiKeyInput = document.getElementById("simklApiKey"); - const validateBtn = document.getElementById("simklApiKeyValidate"); - const toggleBtn = document.getElementById("simklApiKeyToggle"); - const eyeIcon = document.getElementById("simklApiKeyEye"); - const eyeOffIcon = document.getElementById("simklApiKeyEyeOff"); - const validationMessage = document.getElementById("simklValidationMessage"); - - if (!apiKeyInput || !validateBtn || !validationMessage) return; - - // Eye toggle functionality - if (toggleBtn && eyeIcon && eyeOffIcon) { - toggleBtn.addEventListener("click", () => { - const isPassword = apiKeyInput.type === "password"; - apiKeyInput.type = isPassword ? 
"text" : "password"; - eyeIcon.classList.toggle("hidden", !isPassword); - eyeOffIcon.classList.toggle("hidden", isPassword); - }); - } - - // Validation function - async function validateSimklKey() { - const apiKey = apiKeyInput.value.trim(); - - if (!apiKey) { - showSimklValidationMessage("Please enter a Simkl API key", "error"); - return false; - } - - // Show loading state - validateBtn.disabled = true; - validateBtn.classList.add("opacity-50", "cursor-not-allowed"); - const originalHTML = validateBtn.innerHTML; - validateBtn.innerHTML = '<svg class="w-5 h-5 animate-spin" fill="none" stroke="currentColor" viewBox="0 0 24 24"><circle class="opacity-25" cx="12" cy="12" r="10" stroke="currentColor" stroke-width="4"></circle><path class="opacity-75" fill="currentColor" d="M4 12a8 8 0 018-8V0C5.373 0 0 5.373 0 12h4zm2 5.291A7.962 7.962 0 014 12H0c0 3.042 1.135 5.824 3 7.938l3-2.647z"></path></svg>'; - - try { - const response = await fetch("/simkl/validation", { - method: "POST", - headers: { "Content-Type": "application/json" }, - body: JSON.stringify({ api_key: apiKey }) - }); - - const data = await response.json(); - - if (data.valid) { - showSimklValidationMessage("Simkl API key is valid ✓", "success"); - return true; - } else { - showSimklValidationMessage(data.message || "Invalid Simkl API key", "error"); - return false; - } - } catch (error) { - showSimklValidationMessage("Validation failed. Please try again.", "error"); - return false; - } finally { - validateBtn.disabled = false; - validateBtn.classList.remove("opacity-50", "cursor-not-allowed"); - validateBtn.innerHTML = originalHTML; - } - } - - function showSimklValidationMessage(message, type) { - validationMessage.textContent = message; - validationMessage.className = `mt-2 text-xs ${type === "success" ? 
"text-green-400" : "text-red-400"}`; - validationMessage.classList.remove("hidden"); - } - - validateBtn.addEventListener("click", validateSimklKey); - - apiKeyInput.addEventListener("input", () => { - validationMessage.classList.add("hidden"); + initializeValidatedSecretField({ + input: document.getElementById('simklApiKey'), + validateBtn: document.getElementById('simklApiKeyValidate'), + validationMessage: document.getElementById('simklValidationMessage'), + toggleBtn: document.getElementById('simklApiKeyToggle'), + eyeIcon: document.getElementById('simklApiKeyEye'), + eyeOffIcon: document.getElementById('simklApiKeyEyeOff'), + emptyMessage: 'Please enter a Simkl API key', + successMessage: 'Simkl API key is valid ✓', + request: (apiKey) => postJson('/simkl/validation', { api_key: apiKey }), + getErrorMessage: (data) => data.message || 'Invalid Simkl API key' }); } // Gemini AI Integration function initializeGemini() { - const apiKeyInput = document.getElementById("geminiApiKey"); - const validateBtn = document.getElementById("geminiApiKeyValidate"); - const toggleBtn = document.getElementById("geminiApiKeyToggle"); - const eyeIcon = document.getElementById("geminiApiKeyEye"); - const eyeOffIcon = document.getElementById("geminiApiKeyEyeOff"); - const validationMessage = document.getElementById("geminiValidationMessage"); - - if (!apiKeyInput || !validateBtn || !validationMessage) return; - - // Eye toggle functionality - if (toggleBtn && eyeIcon && eyeOffIcon) { - toggleBtn.addEventListener("click", () => { - const isPassword = apiKeyInput.type === "password"; - apiKeyInput.type = isPassword ? 
"text" : "password"; - eyeIcon.classList.toggle("hidden", !isPassword); - eyeOffIcon.classList.toggle("hidden", isPassword); - }); - } - - // Validation function - async function validateGeminiKey() { - const apiKey = apiKeyInput.value.trim(); - - if (!apiKey) { - showGeminiValidationMessage("Please enter a Gemini API key", "error"); - return false; - } - - // Show loading state - validateBtn.disabled = true; - validateBtn.classList.add("opacity-50", "cursor-not-allowed"); - const originalHTML = validateBtn.innerHTML; - validateBtn.innerHTML = '<svg class="w-5 h-5 animate-spin" fill="none" stroke="currentColor" viewBox="0 0 24 24"><circle class="opacity-25" cx="12" cy="12" r="10" stroke="currentColor" stroke-width="4"></circle><path class="opacity-75" fill="currentColor" d="M4 12a8 8 0 018-8V0C5.373 0 0 5.373 0 12h4zm2 5.291A7.962 7.962 0 014 12H0c0 3.042 1.135 5.824 3 7.938l3-2.647z"></path></svg>'; - - try { - const response = await fetch("/gemini/validation", { - method: "POST", - headers: { "Content-Type": "application/json" }, - body: JSON.stringify({ api_key: apiKey }) - }); - - const data = await response.json(); - - if (data.valid) { - showGeminiValidationMessage("Gemini API key is valid ✓", "success"); - return true; - } else { - showGeminiValidationMessage(data.message || "Invalid Gemini API key", "error"); - return false; - } - } catch (error) { - showGeminiValidationMessage("Validation failed. Please try again.", "error"); - return false; - } finally { - validateBtn.disabled = false; - validateBtn.classList.remove("opacity-50", "cursor-not-allowed"); - validateBtn.innerHTML = originalHTML; - } - } - - function showGeminiValidationMessage(message, type) { - validationMessage.textContent = message; - validationMessage.className = `mt-2 text-xs ${type === "success" ? 
"text-green-400" : "text-red-400"}`; - validationMessage.classList.remove("hidden"); - } - - validateBtn.addEventListener("click", validateGeminiKey); - - apiKeyInput.addEventListener("input", () => { - validationMessage.classList.add("hidden"); + initializeValidatedSecretField({ + input: document.getElementById('geminiApiKey'), + validateBtn: document.getElementById('geminiApiKeyValidate'), + validationMessage: document.getElementById('geminiValidationMessage'), + toggleBtn: document.getElementById('geminiApiKeyToggle'), + eyeIcon: document.getElementById('geminiApiKeyEye'), + eyeOffIcon: document.getElementById('geminiApiKeyEyeOff'), + emptyMessage: 'Please enter a Gemini API key', + successMessage: 'Gemini API key is valid ✓', + request: (apiKey) => postJson('/gemini/validation', { api_key: apiKey }), + getErrorMessage: (data) => data.message || 'Invalid Gemini API key' }); } -// Password Toggles function initializePasswordToggles() { - document.querySelectorAll('.toggle-btn').forEach(btn => { - btn.addEventListener('click', () => { - const targetId = btn.getAttribute('data-target'); - const input = document.getElementById(targetId); - if (!input) return; - const isHidden = input.type === 'password'; - input.type = isHidden ? 
'text' : 'password'; - // Swap icon and labels - if (isHidden) { - // Now visible: show eye-off icon - btn.setAttribute('title', 'Hide'); - btn.setAttribute('aria-label', 'Hide password'); - btn.innerHTML = '<svg class="w-4 h-4" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M17.94 17.94A10.94 10.94 0 0 1 12 20c-7 0-11-8-11-8a21.77 21.77 0 0 1 5.06-6.17M9.9 4.24A10.94 10.94 0 0 1 12 4c7 0 11 8 11 8a21.8 21.8 0 0 1-3.22 4.31"/><path d="M1 1l22 22"/><path d="M14.12 14.12A3 3 0 0 1 9.88 9.88"/></svg>'; - } else { - // Now hidden: show eye icon - btn.setAttribute('title', 'Show'); - btn.setAttribute('aria-label', 'Show password'); - btn.innerHTML = '<svg class="w-4 h-4" viewBox="0 0 24 24" fill="none" stroke="currentColor" stroke-width="2" stroke-linecap="round" stroke-linejoin="round"><path d="M1 12s4-7 11-7 11 7 11 7-4 7-11 7-11-7-11-7z"/><circle cx="12" cy="12" r="3"/></svg>'; - } - }); - }); + initializePasswordToggleButton(); } -// Delete & Success Helpers -function initializeSuccessActions() { - const copyBtn = document.getElementById('copyBtn'); - if (copyBtn) { - copyBtn.addEventListener('click', async (e) => { - e.preventDefault(); - e.stopPropagation(); - const urlText = document.getElementById('addonUrl').textContent; - try { - await navigator.clipboard.writeText(urlText); - const originalText = copyBtn.innerHTML; - copyBtn.innerHTML = 'Copied!'; - setTimeout(() => { copyBtn.innerHTML = originalText; }, 2000); - } catch (err) { } - }); - } - - const installDesktopBtn = document.getElementById('installDesktopBtn'); - if (installDesktopBtn) { - installDesktopBtn.addEventListener('click', (e) => { - e.preventDefault(); - e.stopPropagation(); - const url = document.getElementById('addonUrl').textContent; - window.location.href = `stremio://${url.replace(/^https?:\/\//, '')}`; - }); - } - const installWebBtn = document.getElementById('installWebBtn'); - if (installWebBtn) { - 
installWebBtn.addEventListener('click', (e) => { - e.preventDefault(); - e.stopPropagation(); - const url = document.getElementById('addonUrl').textContent; - window.open(`https://web.stremio.com/#/addons?addon=${encodeURIComponent(url)}`, '_blank'); - }); - } - - const deleteAccountBtn = document.getElementById('deleteAccountBtn'); - if (deleteAccountBtn) { - deleteAccountBtn.addEventListener('click', async () => { - const confirmed = await showConfirm( - 'Delete Account?', - 'Are you sure you want to delete your settings? This action is irreversible and all your data will be permanently removed.' - ); - - if (!confirmed) return; - - const sAuthKey = (document.getElementById("authKey").value || '').trim(); - const email = emailInput?.value.trim(); - const password = passwordInput?.value; - - if (!sAuthKey && !(email && password)) { - showError('generalError', "Provide Stremio auth key or email & password to delete your account."); - switchSection('login'); - return; - } - - setLoading(true); - try { - const payload = { authKey: sAuthKey || undefined, email: email || undefined, password: password || undefined }; - const res = await fetch('/tokens/', { method: 'DELETE', headers: { 'Content-Type': 'application/json' }, body: JSON.stringify(payload) }); - if (!res.ok) throw new Error((await res.json()).detail || 'Failed to delete'); - showToast('Account deleted successfully.', 'success'); - if (resetApp) resetApp(); - } catch (e) { - showError('generalError', e.message); - } finally { - setLoading(false); - } - }); - } +function initializeSuccessHandlers() { + initializeSuccessActions({ + emailInput, + passwordInput, + resetApp, + setLoading, + showError + }); } function setLoading(loading) { if (!submitBtn) return; + const btnText = submitBtn.querySelector('.btn-text'); const loader = submitBtn.querySelector('.loader'); submitBtn.disabled = loading; + if (loading) { if (btnText) btnText.classList.add('hidden'); if (loader) loader.classList.remove('hidden'); - } else 
{ - if (btnText) btnText.classList.remove('hidden'); - if (loader) loader.classList.add('hidden'); + return; } + + if (btnText) btnText.classList.remove('hidden'); + if (loader) loader.classList.add('hidden'); } function showError(target, message) { @@ -686,94 +425,114 @@ function showError(target, message) { if (errEl) { errEl.querySelector('.message-content').textContent = message; errEl.classList.remove('hidden'); - } else { showToast(message, 'error'); } - } else if (target === 'stremioAuthSection') { - showToast(message, 'error'); - } else { - const el = document.getElementById(target); - if (el) { - el.classList.add('border-red-500'); - el.focus(); + } else { + showToast(message, 'error'); } + return; + } + + if (target === 'stremioAuthSection') { + showToast(message, 'error'); + return; } + + const element = document.getElementById(target); + if (!element) return; + + element.classList.add('border-red-500'); + element.focus(); } export function clearErrors() { const errEl = document.getElementById('errorMessage'); - if (errEl) errEl.classList.add('hidden'); - document.querySelectorAll('.border-red-500').forEach(e => e.classList.remove('border-red-500')); -} - -function showSuccess(url) { - // Hide form entirely by hiding the active section - const sections = { - welcome: document.getElementById('sect-welcome'), - login: document.getElementById('sect-login'), - config: document.getElementById('sect-config'), - catalogs: document.getElementById('sect-catalogs'), - install: document.getElementById('sect-install'), - success: document.getElementById('sect-success') - }; - Object.values(sections).forEach(s => { if (s) s.classList.add('hidden') }); - - // Show Success Section - if (sections.success) { - sections.success.classList.remove('hidden'); - document.getElementById('addonUrl').textContent = url; + if (errEl) { + errEl.classList.add('hidden'); } -} -// Year Slider Logic -function initializeYearSlider() { - const yearMin = document.getElementById('yearMin'); 
- const yearMax = document.getElementById('yearMax'); - const yearMinLabel = document.getElementById('yearMinLabel'); - const yearMaxLabel = document.getElementById('yearMaxLabel'); - const track = document.getElementById('yearSliderTrack'); + document.querySelectorAll('.border-red-500').forEach(element => { + element.classList.remove('border-red-500'); + }); +} - if (!yearMin || !yearMax || !yearMinLabel || !yearMaxLabel || !track) return; +export function refreshYearSlider() { + updateYearSlider(); +} - function updateSlider() { - const minVal = parseInt(yearMin.value); - const maxVal = parseInt(yearMax.value); +function showSuccess(url) { + showSuccessSection(url); +} - if (minVal > maxVal) { - // Prevent crossing: if min > max, snap them - // This is handled by input listeners to avoid jerky movement +// Watch History Source + OAuth +function initializeWatchHistorySource() { + const traktLoginBtn = document.getElementById('traktLoginBtn'); + const traktStatus = document.getElementById('traktStatus'); + const traktLogoutBtn = document.getElementById('traktLogoutBtn'); + const simklLoginBtn = document.getElementById('simklLoginBtn'); + const simklSyncStatus = document.getElementById('simklSyncStatus'); + const simklSyncLogoutBtn = document.getElementById('simklSyncLogoutBtn'); + + window._watchlyOAuth = window._watchlyOAuth || {}; + + window.addEventListener('message', (event) => { + const data = event.data; + if (!data || !data.provider || !data.tokens) return; + + if (data.provider === 'trakt') { + window._watchlyOAuth.trakt = data.tokens; + if (traktStatus) { + traktStatus.textContent = `Connected as ${data.username || 'Unknown'}`; + traktStatus.classList.remove('text-slate-500'); + traktStatus.classList.add('text-green-400'); + } + if (traktLogoutBtn) traktLogoutBtn.classList.remove('hidden'); + setProviderConnected('trakt', true); + } else if (data.provider === 'simkl') { + window._watchlyOAuth.simkl = data.tokens; + if (simklSyncStatus) { + 
simklSyncStatus.textContent = `Connected as ${data.username || 'Unknown'}`; + simklSyncStatus.classList.remove('text-slate-500'); + simklSyncStatus.classList.add('text-green-400'); + } + if (simklSyncLogoutBtn) simklSyncLogoutBtn.classList.remove('hidden'); + setProviderConnected('simkl', true); } + }); - yearMinLabel.textContent = minVal; - yearMaxLabel.textContent = maxVal; - - const range = yearMin.max - yearMin.min; - const left = ((minVal - yearMin.min) / range) * 100; - const right = ((yearMin.max - maxVal) / range) * 100; - - track.style.left = left + '%'; - track.style.right = right + '%'; + if (traktLoginBtn) { + traktLoginBtn.addEventListener('click', () => { + window.open('/auth/trakt', '_blank', 'width=600,height=700'); + }); } - yearMin.addEventListener('input', () => { - if (parseInt(yearMin.value) > parseInt(yearMax.value)) { - yearMin.value = yearMax.value; - } - yearMin.classList.add('year-slider-active'); - yearMax.classList.remove('year-slider-active'); - updateSlider(); - }); - - yearMax.addEventListener('input', () => { - if (parseInt(yearMax.value) < parseInt(yearMin.value)) { - yearMax.value = yearMin.value; - } - yearMax.classList.add('year-slider-active'); - yearMin.classList.remove('year-slider-active'); - updateSlider(); - }); + if (simklLoginBtn) { + simklLoginBtn.addEventListener('click', () => { + window.open('/auth/simkl', '_blank', 'width=600,height=700'); + }); + } - // Initial update - updateSlider(); + if (traktLogoutBtn) { + traktLogoutBtn.addEventListener('click', () => { + delete window._watchlyOAuth.trakt; + if (traktStatus) { + traktStatus.textContent = 'Not connected'; + traktStatus.classList.remove('text-green-400'); + traktStatus.classList.add('text-slate-500'); + } + traktLogoutBtn.classList.add('hidden'); + setProviderConnected('trakt', false); + }); + } - // Export update function for external population - window.updateYearSlider = updateSlider; + if (simklSyncLogoutBtn) { + simklSyncLogoutBtn.addEventListener('click', 
() => { + delete window._watchlyOAuth.simkl; + if (simklSyncStatus) { + simklSyncStatus.textContent = 'Not connected'; + simklSyncStatus.classList.remove('text-green-400'); + simklSyncStatus.classList.add('text-slate-500'); + } + simklSyncLogoutBtn.classList.add('hidden'); + setProviderConnected('simkl', false); + }); + } } diff --git a/app/static/js/modules/navigation.js b/app/static/js/modules/navigation.js index 96e161e..fbc26eb 100644 --- a/app/static/js/modules/navigation.js +++ b/app/static/js/modules/navigation.js @@ -4,11 +4,13 @@ let navItems = {}; let sections = {}; let mainEl = null; +let appState = null; -export function initializeNavigation(domElements) { +export function initializeNavigation(domElements, state) { navItems = domElements.navItems; sections = domElements.sections; mainEl = domElements.mainEl; + appState = state; Object.keys(navItems).forEach(key => { if (navItems[key]) { @@ -81,6 +83,10 @@ export function initializeMobileNav() { } export function switchSection(sectionKey) { + if (appState) { + appState.ui.currentSection = sectionKey; + } + // Hide all sections Object.values(sections).forEach(el => { if (el) el.classList.add('hidden'); diff --git a/app/static/js/modules/year-slider.js b/app/static/js/modules/year-slider.js new file mode 100644 index 0000000..0e985d3 --- /dev/null +++ b/app/static/js/modules/year-slider.js @@ -0,0 +1,47 @@ +export function initializeYearSliderControl() { + const yearMin = document.getElementById('yearMin'); + const yearMax = document.getElementById('yearMax'); + const yearMinLabel = document.getElementById('yearMinLabel'); + const yearMaxLabel = document.getElementById('yearMaxLabel'); + const track = document.getElementById('yearSliderTrack'); + + if (!yearMin || !yearMax || !yearMinLabel || !yearMaxLabel || !track) { + return () => {}; + } + + function updateSlider() { + const minVal = parseInt(yearMin.value); + const maxVal = parseInt(yearMax.value); + + yearMinLabel.textContent = minVal; + 
yearMaxLabel.textContent = maxVal; + + const range = yearMin.max - yearMin.min; + const left = ((minVal - yearMin.min) / range) * 100; + const right = ((yearMin.max - maxVal) / range) * 100; + + track.style.left = left + '%'; + track.style.right = right + '%'; + } + + yearMin.addEventListener('input', () => { + if (parseInt(yearMin.value) > parseInt(yearMax.value)) { + yearMin.value = yearMax.value; + } + yearMin.classList.add('year-slider-active'); + yearMax.classList.remove('year-slider-active'); + updateSlider(); + }); + + yearMax.addEventListener('input', () => { + if (parseInt(yearMax.value) < parseInt(yearMin.value)) { + yearMax.value = yearMin.value; + } + yearMax.classList.add('year-slider-active'); + yearMin.classList.remove('year-slider-active'); + updateSlider(); + }); + + updateSlider(); + return updateSlider; +} diff --git a/app/static/js/state.js b/app/static/js/state.js new file mode 100644 index 0000000..130196d --- /dev/null +++ b/app/static/js/state.js @@ -0,0 +1,27 @@ +import { defaultCatalogs } from './constants.js'; + +export function cloneDefaultCatalogs() { + return JSON.parse(JSON.stringify(defaultCatalogs)); +} + +export function createAppState() { + return { + auth: { + loggedIn: false, + authKey: '', + userDisplay: null + }, + ui: { + currentSection: 'welcome' + }, + catalogs: cloneDefaultCatalogs() + }; +} + +export function resetAppState(state) { + state.auth.loggedIn = false; + state.auth.authKey = ''; + state.auth.userDisplay = null; + state.ui.currentSection = 'welcome'; + state.catalogs = cloneDefaultCatalogs(); +} diff --git a/app/templates/base.html b/app/templates/base.html index d7ff376..85e32c6 100644 --- a/app/templates/base.html +++ b/app/templates/base.html @@ -292,6 +292,7 @@ <script> // Default catalog configurations from backend window.DEFAULT_CATALOGS = {{ default_catalogs | tojson }}; + window.YEAR_RANGE_DEFAULTS = {{ year_range_defaults | tojson }}; // Genre constants from backend window.MOVIE_GENRES = {{ movie_genres 
| tojson }}; diff --git a/app/templates/components/section_config.html b/app/templates/components/section_config.html index 1ed9016..9a10916 100644 --- a/app/templates/components/section_config.html +++ b/app/templates/components/section_config.html @@ -92,7 +92,7 @@ <h2 class="text-3xl font-bold text-white mb-2">Preferences</h2> <div class="flex items-center justify-between"> <label class="block text-sm font-medium text-slate-400 uppercase tracking-wider">Release Year Range</label> <div class="text-sm font-mono text-white"> - <span id="yearMinLabel">1980</span> — <span id="yearMaxLabel">{{ current_year }}</span> + <span id="yearMinLabel">{{ year_range_defaults.min }}</span> — <span id="yearMaxLabel">{{ current_year }}</span> </div> </div> @@ -103,10 +103,10 @@ <h2 class="text-3xl font-bold text-white mb-2">Preferences</h2> <div id="yearSliderTrack" class="absolute h-1.5 bg-white rounded-full" style="left: 0%; right: 0%;"></div> <!-- Hidden Dual Sliders --> - <input type="range" id="yearMin" min="1970" max="2026" value="1980" step="1" + <input type="range" id="yearMin" min="{{ year_range_defaults.min }}" max="{{ year_range_defaults.max }}" value="{{ year_range_defaults.min }}" step="1" class="absolute w-full appearance-none bg-transparent pointer-events-none z-30 [&::-webkit-slider-thumb]:pointer-events-auto [&::-webkit-slider-thumb]:w-5 [&::-webkit-slider-thumb]:h-5 [&::-webkit-slider-thumb]:rounded-full [&::-webkit-slider-thumb]:bg-white [&::-webkit-slider-thumb]:border-2 [&::-webkit-slider-thumb]:border-black [&::-webkit-slider-thumb]:appearance-none [&::-webkit-slider-thumb]:cursor-pointer [&::-webkit-slider-thumb]:shadow-lg [&::-moz-range-thumb]:pointer-events-auto [&::-moz-range-thumb]:w-5 [&::-moz-range-thumb]:h-5 [&::-moz-range-thumb]:border-2 [&::-moz-range-thumb]:border-black [&::-moz-range-thumb]:rounded-full [&::-moz-range-thumb]:bg-white [&::-moz-range-thumb]:appearance-none [&::-moz-range-thumb]:cursor-pointer [&::-moz-range-thumb]:shadow-lg"> - 
<input type="range" id="yearMax" min="1970" max="2026" value="2026" step="1" + <input type="range" id="yearMax" min="{{ year_range_defaults.min }}" max="{{ year_range_defaults.max }}" value="{{ year_range_defaults.max }}" step="1" class="absolute w-full appearance-none bg-transparent pointer-events-none z-30 [&::-webkit-slider-thumb]:pointer-events-auto [&::-webkit-slider-thumb]:w-5 [&::-webkit-slider-thumb]:h-5 [&::-webkit-slider-thumb]:rounded-full [&::-webkit-slider-thumb]:bg-white [&::-webkit-slider-thumb]:border-2 [&::-webkit-slider-thumb]:border-black [&::-webkit-slider-thumb]:appearance-none [&::-webkit-slider-thumb]:cursor-pointer [&::-webkit-slider-thumb]:shadow-lg [&::-moz-range-thumb]:pointer-events-auto [&::-moz-range-thumb]:w-5 [&::-moz-range-thumb]:h-5 [&::-moz-range-thumb]:border-2 [&::-moz-range-thumb]:border-black [&::-moz-range-thumb]:rounded-full [&::-moz-range-thumb]:bg-white [&::-moz-range-thumb]:appearance-none [&::-moz-range-thumb]:cursor-pointer [&::-moz-range-thumb]:shadow-lg"> </div> @@ -346,6 +346,35 @@ <h2 class="text-3xl font-bold text-white mb-2">Preferences</h2> </div> <div class="border-t border-white/10"></div> + <!-- Watch History Source --> + <div class="space-y-4"> + <div class="flex flex-col sm:flex-row sm:items-center sm:justify-between gap-3"> + <div> + <label class="block text-sm font-medium text-slate-400 uppercase tracking-wider">Watch History Source</label> + <p class="text-xs text-slate-500 mt-1">Where Watchly should read your history from. 
Connect Trakt or Simkl in <button type="button" class="account-link text-slate-300 hover:text-white underline">Accounts</button> to use them.</p> + </div> + <div class="inline-flex items-center bg-neutral-900/60 border border-white/10 rounded-xl p-1 backdrop-blur-sm self-start sm:self-auto" role="group" aria-label="Watch history source"> + <button type="button" + class="source-btn px-4 py-2 text-sm font-medium rounded-lg transition-all outline-none focus:outline-none bg-white/10 text-white border border-white/20 shadow-sm flex items-center gap-1.5" + data-source-btn="stremio">Stremio</button> + <button type="button" + class="source-btn px-4 py-2 text-sm font-medium rounded-lg transition-all outline-none focus:outline-none text-slate-400 hover:text-white hover:bg-white/5 border border-transparent flex items-center gap-1.5" + data-source-btn="trakt"> + <span class="w-1.5 h-1.5 rounded-full bg-slate-600" data-source-pip="trakt"></span> + Trakt + </button> + <button type="button" + class="source-btn px-4 py-2 text-sm font-medium rounded-lg transition-all outline-none focus:outline-none text-slate-400 hover:text-white hover:bg-white/5 border border-transparent flex items-center gap-1.5" + data-source-btn="simkl"> + <span class="w-1.5 h-1.5 rounded-full bg-slate-600" data-source-pip="simkl"></span> + Simkl + </button> + </div> + </div> + <input type="hidden" id="watchHistorySource" value="stremio"> + </div> + <div class="border-t border-white/10"></div> + <!-- Gemini AI Integration --> <div class="space-y-4"> <label class="block text-sm font-medium text-slate-400 uppercase tracking-wider">Gemini AI diff --git a/app/templates/components/section_login.html b/app/templates/components/section_login.html index 1443106..0c1dc0b 100644 --- a/app/templates/components/section_login.html +++ b/app/templates/components/section_login.html @@ -1,97 +1,230 @@ -<!-- SECTION 1: LOGIN --> +<!-- SECTION 1: ACCOUNTS --> <section id="sect-login" class="space-y-6 hidden animate-fade-in"> - 
<div class="mb-8"> - <h2 class="text-3xl font-bold text-white mb-2">Connect to Stremio</h2> - <p class="text-slate-400">Log in to your Stremio account to enable Watchly to read your library.</p> + <div class="mb-6"> + <h2 class="text-3xl font-bold text-white mb-2">Accounts</h2> + <p class="text-slate-400">Stremio is required so Watchly can read your library. Trakt and Simkl are optional — they add explicit ratings and rewatches that sharpen recommendations.</p> </div> - <!-- Logged In Status (Hidden by default, shown after login) --> - <div id="loginStatusSection" class="hidden bg-neutral-900/60 border border-white/10 rounded-2xl p-6 md:p-8 backdrop-blur-sm shadow-xl shadow-black/20"> - <div class="flex items-center justify-between gap-4"> - <div class="flex items-center gap-3 flex-grow min-w-0"> - <div class="w-10 h-10 rounded-full bg-white text-black ring-1 ring-white/10 flex items-center justify-center font-bold text-sm flex-shrink-0" - id="loginStatusAvatar"> - <!-- Avatar initials will be generated from email --> - </div> - <div class="flex-grow min-w-0"> - <div class="text-xs text-slate-500 mb-0.5">Logged in as</div> - <div class="text-sm text-white font-medium truncate" id="loginStatusEmail"> - <!-- Email will be inserted here --> + <!-- Intro / Privacy note --> + <div class="bg-neutral-900/40 border border-white/5 rounded-2xl p-5 text-xs text-slate-400 leading-relaxed"> + <div class="flex items-start gap-3"> + <svg class="w-4 h-4 mt-0.5 text-slate-500 flex-shrink-0" fill="none" stroke="currentColor" viewBox="0 0 24 24"> + <path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M13 16h-1v-4h-1m1-4h.01M21 12a9 9 0 11-18 0 9 9 0 0118 0z"/> + </svg> + <div> + Credentials and OAuth tokens are encrypted at rest and only used to read your watch history. You can disconnect any provider at any time. The watch history source is selected on the next page. 
+ </div> + </div> + </div> + + <!-- STREMIO --> + <div id="provider-stremio" class="bg-neutral-900/60 border border-white/10 rounded-2xl p-6 md:p-8 backdrop-blur-sm shadow-xl shadow-black/20"> + <div class="flex items-start justify-between gap-4 mb-5"> + <div class="flex items-center gap-3"> + <img src="https://stremio.com/website/stremio-logo-small.png" class="w-8 h-8" alt="Stremio"> + <div> + <div class="flex items-center gap-2"> + <h3 class="text-lg font-semibold text-white">Stremio</h3> + <span class="text-[10px] uppercase tracking-wider px-2 py-0.5 rounded-full bg-cyan-500/15 text-cyan-300 border border-cyan-400/20">Required</span> </div> + <p class="text-xs text-slate-500 mt-0.5">Reads your Stremio library to build recommendations.</p> </div> </div> - <button type="button" id="loginStatusLogoutBtn" - class="flex-shrink-0 bg-red-600 hover:bg-red-700 text-white font-medium py-2.5 px-4 rounded-xl transition border border-red-700 shadow-lg shadow-red-900/20 flex items-center justify-center gap-2"> - <svg class="w-4 h-4" fill="none" stroke="currentColor" viewBox="0 0 24 24"> - <path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" - d="M17 16l4-4m0 0l-4-4m4 4H7m6 4v1a3 3 0 01-3 3H6a3 3 0 01-3-3V7a3 3 0 013-3h4a3 3 0 013 3v1"> - </path> - </svg> - <span>Logout</span> - </button> + <span class="w-2 h-2 rounded-full bg-slate-500 mt-2 flex-shrink-0" data-account-dot="stremio"></span> </div> - </div> - <div id="loginFormCard" class="bg-neutral-900/60 border border-white/10 rounded-2xl p-6 md:p-8 backdrop-blur-sm shadow-xl shadow-black/20"> - <button type="button" id="stremioLoginBtn" - class="w-full bg-stremio text-white font-medium py-4 rounded-xl transition flex items-center justify-center gap-3 border border-stremio-border shadow-lg shadow-stremio/20 group hover:bg-white hover:text-black hover:border-white/10"> - <img src="https://stremio.com/website/stremio-logo-small.png" - class="w-6 h-6 group-hover:scale-110 transition-transform" alt="Stremio"> - 
<span id="stremioLoginText" class="text-lg">Login with Stremio</span> - </button> - - <input type="hidden" id="authKey"> - - <!-- Divider --> - <div id="emailPwdDivider" class="flex items-center gap-3 my-6"> - <div class="h-px bg-white/10 w-full"></div> - <div class="text-xs text-slate-500">or</div> - <div class="h-px bg-white/10 w-full"></div> + <!-- Disconnected view --> + <div id="loginFormCard" data-provider-view="disconnected" data-provider-for="stremio"> + <button type="button" id="stremioLoginBtn" + class="w-full bg-stremio text-white font-medium py-4 rounded-xl transition flex items-center justify-center gap-3 border border-stremio-border shadow-lg shadow-stremio/20 group hover:bg-white hover:text-black hover:border-white/10"> + <img src="https://stremio.com/website/stremio-logo-small.png" + class="w-6 h-6 group-hover:scale-110 transition-transform" alt="Stremio"> + <span id="stremioLoginText" class="text-lg">Login with Stremio</span> + </button> + + <input type="hidden" id="authKey"> + + <div id="emailPwdDivider" class="flex items-center gap-3 my-6"> + <div class="h-px bg-white/10 w-full"></div> + <div class="text-xs text-slate-500">or</div> + <div class="h-px bg-white/10 w-full"></div> + </div> + + <div id="emailPwdSection" class="grid gap-3"> + <label class="text-xs text-slate-400">Email</label> + <input id="emailInput" type="email" autocomplete="email" inputmode="email" + spellcheck="false" required placeholder="you@example.com" + class="w-full bg-neutral-900 border border-slate-700 rounded-xl px-4 py-3.5 text-white placeholder-slate-500 focus:ring-2 focus:ring-white/20 focus:border-white/30 outline-none transition-all"> + <label class="text-xs text-slate-400">Password</label> + <div class="relative"> + <input id="passwordInput" type="password" autocomplete="current-password" + placeholder="Your Stremio password" + class="w-full bg-neutral-900 border border-slate-700 rounded-xl pl-4 pr-12 py-3.5 text-white placeholder-slate-500 focus:ring-2 
focus:ring-white/20 focus:border-white/30 outline-none transition-all"> + <button type="button" + class="toggle-btn absolute right-2 top-1/2 -translate-y-1/2 bg-white/10 hover:bg-white/20 text-white p-2 rounded-lg" + aria-label="Show password" title="Show" data-target="passwordInput"> + <svg class="w-4 h-4" viewBox="0 0 24 24" fill="none" stroke="currentColor" + stroke-width="2" stroke-linecap="round" stroke-linejoin="round"> + <path d="M1 12s4-7 11-7 11 7 11 7-4 7-11 7-11-7-11-7z" /> + <circle cx="12" cy="12" r="3" /> + </svg> + </button> + </div> + <button type="button" id="emailPwdContinueBtn" + class="mt-2 w-full bg-white text-black hover:bg-white/90 font-medium py-3 rounded-xl transition border border-white/10 flex items-center justify-center gap-2"> + <span class="btn-text">Continue with Email</span> + <div class="loader hidden w-5 h-5 border-2 border-black/30 border-t-black rounded-full animate-spin"></div> + </button> + </div> + + <div id="emailPwdError" class="hidden mt-3 p-3 bg-red-500/10 border border-red-500/20 rounded-xl text-red-200 text-sm"></div> + + <div id="emailPwdDisclaimer" + class="mt-4 text-xs leading-relaxed bg-yellow-500/10 border border-yellow-500/30 text-yellow-200 rounded-xl p-3"> + <strong class="text-yellow-300">Why email & password?</strong> + <span class="block mt-1">We store your credentials securely so we can refresh your Stremio auth key automatically. This avoids expired keys and keeps your addon working without manual re-login.</span> + <span class="block mt-2">Prefer not to share your password? 
Use the Stremio button above to supply an auth key instead — note that auth keys can expire and may require periodic re-authentication.</span> + </div> </div> - <!-- Email/Password Login --> - <div id="emailPwdSection" class="grid gap-3"> - <label class="text-xs text-slate-400">Email</label> - <input id="emailInput" type="email" autocomplete="email" inputmode="email" - spellcheck="false" required placeholder="you@example.com" - class="w-full bg-neutral-900 border border-slate-700 rounded-xl px-4 py-3.5 text-white placeholder-slate-500 focus:ring-2 focus:ring-white/20 focus:border-white/30 outline-none transition-all"> - <label class="text-xs text-slate-400">Password</label> - <div class="relative"> - <input id="passwordInput" type="password" autocomplete="current-password" - placeholder="Your Stremio password" - class="w-full bg-neutral-900 border border-slate-700 rounded-xl pl-4 pr-12 py-3.5 text-white placeholder-slate-500 focus:ring-2 focus:ring-white/20 focus:border-white/30 outline-none transition-all"> - <button type="button" - class="toggle-btn absolute right-2 top-1/2 -translate-y-1/2 bg-white/10 hover:bg-white/20 text-white p-2 rounded-lg" - aria-label="Show password" title="Show" data-target="passwordInput"> - <svg class="w-4 h-4" viewBox="0 0 24 24" fill="none" stroke="currentColor" - stroke-width="2" stroke-linecap="round" stroke-linejoin="round"> - <path d="M1 12s4-7 11-7 11 7 11 7-4 7-11 7-11-7-11-7z" /> - <circle cx="12" cy="12" r="3" /> + <!-- Connected view --> + <div data-provider-view="connected" data-provider-for="stremio" class="hidden"> + <div class="flex items-center justify-between gap-4"> + <div class="flex items-center gap-3 flex-grow min-w-0"> + <div class="w-10 h-10 rounded-full bg-white text-black ring-1 ring-white/10 flex items-center justify-center font-bold text-sm flex-shrink-0" + id="loginStatusAvatar"></div> + <div class="flex-grow min-w-0"> + <div class="text-xs text-green-400 mb-0.5 flex items-center gap-1.5"> + <svg class="w-3 
h-3" fill="currentColor" viewBox="0 0 20 20"><path fill-rule="evenodd" d="M10 18a8 8 0 100-16 8 8 0 000 16zm3.707-9.293a1 1 0 00-1.414-1.414L9 10.586 7.707 9.293a1 1 0 00-1.414 1.414l2 2a1 1 0 001.414 0l4-4z" clip-rule="evenodd"/></svg> + Connected + </div> + <div class="text-sm text-white font-medium truncate" id="loginStatusEmail"></div> + </div> + </div> + <button type="button" id="loginStatusLogoutBtn" + class="flex-shrink-0 bg-red-600/90 hover:bg-red-600 text-white font-medium py-2.5 px-4 rounded-xl transition border border-red-500/40 shadow-lg shadow-red-900/20 flex items-center gap-2"> + <svg class="w-4 h-4" fill="none" stroke="currentColor" viewBox="0 0 24 24"> + <path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" + d="M17 16l4-4m0 0l-4-4m4 4H7m6 4v1a3 3 0 01-3 3H6a3 3 0 01-3-3V7a3 3 0 013-3h4a3 3 0 013 3v1"/> </svg> + <span>Logout</span> </button> </div> - <button type="button" id="emailPwdContinueBtn" - class="mt-2 w-full bg-white text-black hover:bg-white/90 font-medium py-3 rounded-xl transition border border-white/10 flex items-center justify-center gap-2"> - <span class="btn-text">Continue with Email</span> - <div class="loader hidden w-5 h-5 border-2 border-black/30 border-t-black rounded-full animate-spin"> + </div> + </div> + + <!-- TRAKT --> + <div id="provider-trakt" class="bg-neutral-900/60 border border-white/10 rounded-2xl p-6 md:p-8 backdrop-blur-sm shadow-xl shadow-black/20"> + <div class="flex items-start justify-between gap-4 mb-5"> + <div class="flex items-center gap-3"> + <div class="w-8 h-8 rounded-lg bg-red-500/15 border border-red-400/20 flex items-center justify-center text-red-400"> + <svg class="w-5 h-5" fill="currentColor" viewBox="0 0 24 24"> + <path d="M12 0C5.4 0 0 5.4 0 12s5.4 12 12 12 12-5.4 12-12S18.6 0 12 0zm-1.4 17.5L4.7 11.6l1.1-1.1 4.8 4.8 7.6-7.6 1.1 1.1-8.7 8.7zm0-3.4L6.7 10.2l1.1-1.1 2.8 2.8 5.6-5.6 1.1 1.1-6.7 6.7z"/> + </svg> </div> + <div> + <div class="flex items-center gap-2"> + <h3 
class="text-lg font-semibold text-white">Trakt</h3> + <span class="text-[10px] uppercase tracking-wider px-2 py-0.5 rounded-full bg-slate-500/15 text-slate-300 border border-slate-400/20">Optional</span> + </div> + <p class="text-xs text-slate-500 mt-0.5">Adds explicit ratings, rewatches, and stronger taste signal.</p> + </div> + </div> + <span class="w-2 h-2 rounded-full bg-slate-500 mt-2 flex-shrink-0" data-account-dot="trakt"></span> + </div> + + <!-- Disconnected view --> + <div data-provider-view="disconnected" data-provider-for="trakt"> + <button type="button" id="traktLoginBtn" + class="w-full bg-red-600/20 hover:bg-red-600/30 border border-red-500/30 text-red-300 hover:text-red-200 font-medium py-4 rounded-xl transition flex items-center justify-center gap-3"> + <svg class="w-5 h-5" fill="currentColor" viewBox="0 0 24 24"> + <path d="M12 0C5.4 0 0 5.4 0 12s5.4 12 12 12 12-5.4 12-12S18.6 0 12 0zm-1.4 17.5L4.7 11.6l1.1-1.1 4.8 4.8 7.6-7.6 1.1 1.1-8.7 8.7z"/> + </svg> + <span class="text-base">Login with Trakt</span> </button> + <p class="mt-4 text-xs text-slate-500 leading-relaxed"> + Trakt tracks every play and rating you make. Watchly uses your ratings to weight loved vs. liked items differently — much sharper than play counts alone. Pick Trakt as your watch history source on the next page once connected. 
+ </p> </div> - <!-- Inline error for email/password login --> - <div id="emailPwdError" class="hidden mt-3 p-3 bg-red-500/10 border border-red-500/20 rounded-xl text-red-200 text-sm"> + <!-- Connected view --> + <div data-provider-view="connected" data-provider-for="trakt" class="hidden"> + <div class="flex items-center justify-between gap-4"> + <div class="flex items-center gap-3"> + <div class="w-10 h-10 rounded-full bg-red-500/15 text-red-300 ring-1 ring-red-400/20 flex items-center justify-center"> + <svg class="w-5 h-5" fill="currentColor" viewBox="0 0 20 20"><path fill-rule="evenodd" d="M10 18a8 8 0 100-16 8 8 0 000 16zm3.707-9.293a1 1 0 00-1.414-1.414L9 10.586 7.707 9.293a1 1 0 00-1.414 1.414l2 2a1 1 0 001.414 0l4-4z" clip-rule="evenodd"/></svg> + </div> + <div> + <div class="text-xs text-green-400 mb-0.5">Connected</div> + <div class="text-sm text-white font-medium" id="traktStatus">Trakt</div> + </div> + </div> + <button type="button" id="traktLogoutBtn" + class="bg-red-600/20 hover:bg-red-600/30 border border-red-500/30 text-red-300 hover:text-red-200 text-sm font-medium py-2 px-4 rounded-lg transition"> + Disconnect + </button> + </div> </div> + </div> - <!-- Disclaimer --> - <div id="emailPwdDisclaimer" - class="mt-4 text-xs leading-relaxed bg-yellow-500/10 border border-yellow-500/30 text-yellow-200 rounded-xl p-3"> - <strong class="text-yellow-300">Why email & password?</strong> - <span class="block mt-1">We store your credentials securely to generate a fresh Stremio auth - key automatically when needed. This avoids expired keys and keeps your addon working - without manual re-login.</span> - <span class="block mt-2">Prefer not to share your password? Use the Stremio login above to - supply an auth key. 
Note: auth keys can expire and may require periodic - re-authentication.</span> + <!-- SIMKL --> + <div id="provider-simkl" class="bg-neutral-900/60 border border-white/10 rounded-2xl p-6 md:p-8 backdrop-blur-sm shadow-xl shadow-black/20"> + <div class="flex items-start justify-between gap-4 mb-5"> + <div class="flex items-center gap-3"> + <div class="w-8 h-8 rounded-lg bg-blue-500/15 border border-blue-400/20 flex items-center justify-center text-blue-400"> + <svg class="w-5 h-5" fill="currentColor" viewBox="0 0 24 24"> + <path d="M12 0C5.4 0 0 5.4 0 12s5.4 12 12 12 12-5.4 12-12S18.6 0 12 0zm0 4.5c1.4 0 2.5 1.1 2.5 2.5S13.4 9.5 12 9.5 9.5 8.4 9.5 7s1.1-2.5 2.5-2.5zm5 13H7v-1.5c0-1.7 3.3-2.5 5-2.5s5 .8 5 2.5V17.5z"/> + </svg> + </div> + <div> + <div class="flex items-center gap-2"> + <h3 class="text-lg font-semibold text-white">Simkl</h3> + <span class="text-[10px] uppercase tracking-wider px-2 py-0.5 rounded-full bg-slate-500/15 text-slate-300 border border-slate-400/20">Optional</span> + </div> + <p class="text-xs text-slate-500 mt-0.5">Detailed show tracking with ratings and watch dates.</p> + </div> + </div> + <span class="w-2 h-2 rounded-full bg-slate-500 mt-2 flex-shrink-0" data-account-dot="simkl"></span> </div> + + <!-- Disconnected view --> + <div data-provider-view="disconnected" data-provider-for="simkl"> + <button type="button" id="simklLoginBtn" + class="w-full bg-blue-600/20 hover:bg-blue-600/30 border border-blue-500/30 text-blue-300 hover:text-blue-200 font-medium py-4 rounded-xl transition flex items-center justify-center gap-3"> + <svg class="w-5 h-5" fill="currentColor" viewBox="0 0 24 24"> + <path d="M12 0C5.4 0 0 5.4 0 12s5.4 12 12 12 12-5.4 12-12S18.6 0 12 0zm0 4.5c1.4 0 2.5 1.1 2.5 2.5S13.4 9.5 12 9.5 9.5 8.4 9.5 7s1.1-2.5 2.5-2.5z"/> + </svg> + <span class="text-base">Login with Simkl</span> + </button> + <p class="mt-4 text-xs text-slate-500 leading-relaxed"> + Simkl is especially strong for series — it tracks individual episodes and adds 
rich rating signal. Pick Simkl as your watch history source on the next page once connected. + </p> + </div> + + <!-- Connected view --> + <div data-provider-view="connected" data-provider-for="simkl" class="hidden"> + <div class="flex items-center justify-between gap-4"> + <div class="flex items-center gap-3"> + <div class="w-10 h-10 rounded-full bg-blue-500/15 text-blue-300 ring-1 ring-blue-400/20 flex items-center justify-center"> + <svg class="w-5 h-5" fill="currentColor" viewBox="0 0 20 20"><path fill-rule="evenodd" d="M10 18a8 8 0 100-16 8 8 0 000 16zm3.707-9.293a1 1 0 00-1.414-1.414L9 10.586 7.707 9.293a1 1 0 00-1.414 1.414l2 2a1 1 0 001.414 0l4-4z" clip-rule="evenodd"/></svg> + </div> + <div> + <div class="text-xs text-green-400 mb-0.5">Connected</div> + <div class="text-sm text-white font-medium" id="simklSyncStatus">Simkl</div> + </div> + </div> + <button type="button" id="simklSyncLogoutBtn" + class="bg-blue-600/20 hover:bg-blue-600/30 border border-blue-500/30 text-blue-300 hover:text-blue-200 text-sm font-medium py-2 px-4 rounded-lg transition"> + Disconnect + </button> + </div> + </div> + </div> + + <!-- Footer / Next --> + <div class="pt-4 flex justify-end"> + <button type="button" id="accountsNextBtn" + class="bg-white text-black hover:bg-white/90 border border-white/10 px-8 py-3 rounded-xl font-medium transition hover:-translate-y-0.5 hover:ring-2 hover:ring-black/10 focus:outline-none focus:ring-2 focus:ring-black/20 active:translate-y-0 disabled:opacity-50 disabled:cursor-not-allowed disabled:hover:translate-y-0 disabled:hover:ring-0" + disabled>Next: Configure Options →</button> </div> </section> diff --git a/app/templates/components/section_welcome.html b/app/templates/components/section_welcome.html index 3fb51c2..e5818e8 100644 --- a/app/templates/components/section_welcome.html +++ b/app/templates/components/section_welcome.html @@ -63,17 +63,17 @@ <h1 class="text-4xl md:text-5xl font-bold text-white"> <div class="grid md:grid-cols-2 
lg:grid-cols-3 gap-4 max-w-5xl mx-auto"> <!-- Feature 1 --> <div - class="group bg-gradient-to-br from-slate-900/50 to-slate-800/30 border border-slate-700/50 hover:border-blue-500/50 rounded-xl p-4 transition-all hover:shadow-lg hover:shadow-blue-900/10"> + class="group bg-gradient-to-br from-slate-900/50 to-slate-800/30 border border-slate-700/50 hover:border-rose-500/50 rounded-xl p-4 transition-all hover:shadow-lg hover:shadow-rose-900/10"> <div - class="w-10 h-10 bg-blue-500/10 text-blue-400 rounded-lg flex items-center justify-center mb-3 group-hover:scale-110 transition-transform"> + class="w-10 h-10 bg-rose-500/10 text-rose-400 rounded-lg flex items-center justify-center mb-3 group-hover:scale-110 transition-transform"> <svg class="w-5 h-5" fill="none" stroke="currentColor" viewBox="0 0 24 24"> <path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" - d="M13 10V3L4 14h7v7l9-11h-7z"></path> + d="M13.828 10.172a4 4 0 015.656 5.656l-3 3a4 4 0 01-5.656-5.656m1.414-2.828a4 4 0 00-5.656 0l-3 3a4 4 0 005.656 5.656l1.414-1.414"></path> </svg> </div> - <h3 class="text-base font-bold text-white mb-1">Smart Recommendations</h3> + <h3 class="text-base font-bold text-white mb-1">Trakt & Simkl Integration</h3> <p class="text-xs text-slate-400 leading-relaxed"> - AI-powered suggestions based on your watch history, library and your reactions. + Optional. Connect Trakt or Simkl for explicit ratings and rewatches that sharpen recommendations. 
</p> </div> diff --git a/app/templates/components/sidebar.html b/app/templates/components/sidebar.html index 865c3fe..3b0bcee 100644 --- a/app/templates/components/sidebar.html +++ b/app/templates/components/sidebar.html @@ -84,11 +84,11 @@ <h1 class="font-bold text-2xl text-transparent bg-clip-text bg-gradient-to-r fro class="w-8 h-8 rounded-lg bg-cyan-500/10 text-cyan-400 flex items-center justify-center border border-cyan-400/20 flex-shrink-0 transition-all group-hover:scale-105 group-hover:bg-cyan-500/15 group-hover:border-cyan-400/30 group-hover:shadow-lg group-hover:shadow-cyan-900/10"> <svg class="w-4 h-4" fill="none" stroke="currentColor" viewBox="0 0 24 24"> <path stroke-linecap="round" stroke-linejoin="round" stroke-width="2" - d="M11 16l-4-4m0 0l4-4m-4 4h14m-5 4v1a3 3 0 01-3 3H6a3 3 0 01-3-3V7a3 3 0 013-3h7a3 3 0 013 3v1"> + d="M13.828 10.172a4 4 0 015.656 5.656l-3 3a4 4 0 01-5.656-5.656m1.414-2.828a4 4 0 00-5.656 0l-3 3a4 4 0 005.656 5.656l1.414-1.414"> </path> </svg> </div> - <span>Login to Stremio</span> + <span>Accounts</span> </button> <button id="nav-config" diff --git a/app/utils/catalog.py b/app/utils/catalog.py deleted file mode 100644 index e5fc6b6..0000000 --- a/app/utils/catalog.py +++ /dev/null @@ -1,102 +0,0 @@ -from typing import Any - -from app.core.constants import DISCOVER_ONLY_EXTRA -from app.core.settings import UserSettings -from app.services.profile.integration import ProfileIntegration -from app.services.stremio.service import StremioBundle -from app.services.user_cache import user_cache - - -def get_catalogs_from_config( - user_settings: UserSettings, - cat_id: str, - default_name: str, - default_movie: bool, - default_series: bool, -): - catalogs = [] - config = next((c for c in user_settings.catalogs if c.id == cat_id), None) - - if config and config.enabled: - name = config.name if config and config.name else default_name - enabled_movie = getattr(config, "enabled_movie", default_movie) if config else default_movie - 
enabled_series = getattr(config, "enabled_series", default_series) if config else default_series - display_at_home = getattr(config, "display_at_home", True) if config else True - - extra = DISCOVER_ONLY_EXTRA if not display_at_home else [] - - if enabled_movie: - catalogs.append({"type": "movie", "id": cat_id, "name": name, "extra": extra}) - if enabled_series: - catalogs.append({"type": "series", "id": cat_id, "name": name, "extra": extra}) - return catalogs - - -async def cache_profile_and_watched_sets( - token: str, - content_type: str, - integration_service: ProfileIntegration, - library_items: dict, - bundle: StremioBundle, - auth_key: str, -): - """ - Build and cache profile and watched sets for a user and content type. - Uses the centralized UserCacheService for caching. - """ - ( - profile, - watched_tmdb, - watched_imdb, - ) = await integration_service.build_profile_incremental(library_items, content_type, token, bundle, auth_key) - - await user_cache.set_profile_and_watched_sets(token, content_type, profile, watched_tmdb, watched_imdb) - return profile, watched_tmdb, watched_imdb - - -def get_config_id(catalog) -> str | None: - catalog_id = catalog.get("id", "") - if catalog_id.startswith("watchly.theme."): - return "watchly.theme" - if catalog_id.startswith("watchly.loved."): - return "watchly.loved" - if catalog_id.startswith("watchly.watched."): - return "watchly.watched" - return catalog_id - - -def sort_catalogs(catalogs: list[dict[str, Any]], user_settings: UserSettings) -> list[dict[str, Any]]: - """ - Sort catalogs according to user settings and sorting order choice. 
- - Sorting Orders: - - default: Interleaved (by category priority, then movie then series) - - movies_first: Group all movies first, then all series - - series_first: Group all series first, then all movies - """ - if not user_settings: - return catalogs - - # Get the original order index from user settings for each catalog category - order_map = {c.id: i for i, c in enumerate(user_settings.catalogs)} - - # Base sorting key: setting index (priority) - def get_setting_index(cat): - return order_map.get(get_config_id(cat), 999) - - sorting_order = getattr(user_settings, "sorting_order", "default") - - if sorting_order == "movies_first": - # Group movies first, then series - # movies: type_priority=0, series: type_priority=1 - sorted_catalogs = sorted(catalogs, key=lambda x: (0 if x.get("type") == "movie" else 1, get_setting_index(x))) - elif sorting_order == "series_first": - # Group series first, then movies - # series: type_priority=0, movies: type_priority=1 - sorted_catalogs = sorted(catalogs, key=lambda x: (0 if x.get("type") == "series" else 1, get_setting_index(x))) - else: - # Default: Interleaved (by category priority) - # Python's sorted is stable, preserving movie then series within same priority - sorted_catalogs = sorted(catalogs, key=get_setting_index) - - return sorted_catalogs diff --git a/main.py b/main.py index 7709a1b..f301320 100644 --- a/main.py +++ b/main.py @@ -1,5 +1,7 @@ import os +from loguru import logger + from app.core.app import app # noqa: F401 from app.core.config import settings @@ -8,4 +10,5 @@ PORT = os.getenv("PORT", settings.PORT) reload = settings.APP_ENV == "development" + logger.info(f"Starting Watchly: APP_ENV={settings.APP_ENV} reload={reload} port={PORT}") uvicorn.run("app.core.app:app", host="0.0.0.0", port=int(PORT), reload=reload) diff --git a/tests/test_catalog_endpoint.py b/tests/test_catalog_endpoint.py new file mode 100644 index 0000000..b6ebca1 --- /dev/null +++ b/tests/test_catalog_endpoint.py @@ -0,0 +1,32 @@ 
+from fastapi.testclient import TestClient + +from app.core.app import app +from app.services.recommendation.catalog_service import catalog_service + +client = TestClient(app) + + +def test_catalog_endpoint_keeps_cache_header_for_non_empty_results(monkeypatch): + async def fake_get_catalog(token: str, content_type: str, catalog_id: str): + return {"metas": [{"id": "tt1234567", "type": "movie", "name": "Example"}]}, {"Cache-Control": "public"} + + monkeypatch.setattr(catalog_service, "get_catalog", fake_get_catalog) + + response = client.get("/abc/catalog/movie/watchly.rec.json") + + assert response.status_code == 200 + assert response.headers["Cache-Control"] == "public" + assert response.json()["metas"] + + +def test_catalog_endpoint_marks_empty_results_as_no_cache(monkeypatch): + async def fake_get_catalog(token: str, content_type: str, catalog_id: str): + return {"metas": []}, {"Cache-Control": "public"} + + monkeypatch.setattr(catalog_service, "get_catalog", fake_get_catalog) + + response = client.get("/abc/catalog/movie/watchly.rec.json") + + assert response.status_code == 200 + assert response.headers["Cache-Control"] == "no-cache" + assert response.json() == {"metas": []} diff --git a/tests/test_configure_page.py b/tests/test_configure_page.py new file mode 100644 index 0000000..9983722 --- /dev/null +++ b/tests/test_configure_page.py @@ -0,0 +1,27 @@ +import importlib + +from fastapi.testclient import TestClient + +from app.core.app import app + +client = TestClient(app) +app_module = importlib.import_module("app.core.app") + + +def test_configure_page_bootstraps_current_year_and_year_defaults(monkeypatch): + async def fake_fetch_languages_list(): + return [{"iso_639_1": "en-US", "language": "English", "country": "US"}] + + async def fake_count_users(): + return 7 + + monkeypatch.setattr(app_module, "fetch_languages_list", fake_fetch_languages_list) + monkeypatch.setattr(app_module.token_store, "count_users", fake_count_users) + + response = 
client.get("/configure") + + assert response.status_code == 200 + html = response.text + assert 'window.YEAR_RANGE_DEFAULTS = {"min": 1970, "max": ' in html + assert 'id="yearMin" min="1970"' in html + assert 'id="yearMax" min="1970"' in html diff --git a/tests/test_token_store_migration.py b/tests/test_token_store_migration.py new file mode 100644 index 0000000..601e201 --- /dev/null +++ b/tests/test_token_store_migration.py @@ -0,0 +1,44 @@ +import asyncio +import json + +from app.services.token_store import TokenStore + + +def test_migrate_poster_rating_preserves_migrated_api_key(monkeypatch): + store = TokenStore() + writes: list[dict] = [] + + async def fake_set(key: str, value: str, ttl=None): + writes.append({"key": key, "value": value, "ttl": ttl}) + return True + + monkeypatch.setattr("app.services.token_store.redis_service.set", fake_set) + + payload = { + "settings": { + "rpdb_key": "plain-api-key", + } + } + + updated = asyncio.run(store._migrate_poster_rating_format_raw("user123", "watchly:token:user123", payload)) + + assert updated is not None + assert "rpdb_key" not in updated["settings"] + assert updated["settings"]["poster_rating"]["provider"] == "rpdb" + assert updated["settings"]["poster_rating"]["api_key"] is not None + assert updated["settings"]["poster_rating"]["api_key"] != "plain-api-key" + assert writes + + stored_payload = json.loads(writes[0]["value"]) + assert stored_payload["settings"]["poster_rating"]["api_key"] is not None + + +def test_token_request_defaults_match_user_settings_defaults(): + from app.api.models.tokens import TokenRequest + from app.core.settings import UserSettings, get_default_year_max + + token_request = TokenRequest() + user_settings = UserSettings(catalogs=[]) + + assert token_request.year_min == user_settings.year_min == 1970 + assert token_request.year_max == user_settings.year_max == get_default_year_max()