diff --git a/src/scriptworker/utils.py b/src/scriptworker/utils.py index 9e16166b..f4679cbb 100644 --- a/src/scriptworker/utils.py +++ b/src/scriptworker/utils.py @@ -521,7 +521,7 @@ def format_json(data): # is omitted we don't actually ever return None (because on failure we raise an Exception) @overload def load_json_or_yaml( - string: str, is_path: Optional[bool] = ..., file_type: Optional[str] = ..., exception: Type[BaseException] = ..., message: str = ... + string: str, is_path: Optional[bool] = ..., file_type: Optional[str] = ..., exception: Type[BaseException] = ..., message: Optional[str] = ... ) -> Dict[str, Any]: # pragma: no cover ... @@ -532,7 +532,7 @@ def load_json_or_yaml( is_path: Optional[bool] = False, file_type: Optional[str] = "json", exception: Optional[Type[BaseException]] = ScriptWorkerTaskException, - message: str = "Failed to load %(file_type)s: %(exc)s", + message: Optional[str] = None, ) -> Optional[Dict[str, Any]]: # pragma: no cover ... @@ -542,7 +542,7 @@ def load_json_or_yaml( is_path: Optional[bool] = False, file_type: Optional[str] = "json", exception: Optional[Type[BaseException]] = ScriptWorkerTaskException, - message: str = "Failed to load %(file_type)s: %(exc)s", + message: Optional[str] = None, ) -> Optional[Dict[str, Any]]: """Load json or yaml from a filehandle or string, and raise a custom exception on failure. @@ -552,8 +552,10 @@ def load_json_or_yaml( file_type (str, optional): either "json" or "yaml". Defaults to "json". exception (exception, optional): the exception to raise on failure. If None, don't raise an exception. Defaults to ScriptWorkerTaskException. - message (str, optional): the message to use for the exception. - Defaults to "Failed to load %(file_type)s: %(exc)s" + message (str, optional): override the exception message. Supports the + ``%(file_type)s``, ``%(exc)s``, and ``%(path)s`` placeholders (the + latter is empty when ``is_path=False``). Defaults to a message that + includes the path when ``is_path=True``, otherwise one that omits it. Returns: dict: the data from the string. @@ -578,7 +580,9 @@ def load_json_or_yaml( return contents except (OSError, ValueError, yaml.scanner.ScannerError) as exc: if exception is not None: - repl_dict = {"exc": str(exc), "file_type": file_type} + if message is None: + raise exception(f"Failed to load {file_type} from {string}: {exc}" if is_path else f"Failed to load {file_type}: {exc}") + repl_dict = {"exc": str(exc), "file_type": file_type, "path": string if is_path else ""} raise exception(message % repl_dict) return None @@ -652,7 +656,7 @@ async def _log_download_error(resp, msg): log.debug("Redirect history %s: %s; body=%s", get_loggable_url(str(h.url)), h.status, (await h.text())[:1000]) -async def download_file(context, url, abs_filename, session=None, chunk_size=128, auth=None): +async def download_file(context, url, abs_filename, session=None, chunk_size=128, auth=None, expected_content_type=None): """Download a file, async. Args: @@ -663,6 +667,16 @@ async def download_file(context, url, abs_filename, session=None, chunk_size=128 None, use context.session. Defaults to None. chunk_size (int, optional): the chunk size to read from the response at a time. Default is 128. + expected_content_type (str, optional): if set, raise ``DownloadError`` + when the server returns an ``HTML`` response and ``expected_content_type`` + is something other than HTML. Narrow by design — servers vary too + much for a strict match — but catches the common + "error page instead of the JSON/YAML/artifact we asked for" case. + + Raises: + DownloadError: on non-200 status, or an HTML response when + ``expected_content_type`` was not HTML. + Download404: on 404 status. """ session = session or context.session @@ -675,10 +689,15 @@ async def download_file(context, url, abs_filename, session=None, chunk_size=128 async with session.get(url, auth=auth) as resp: if resp.status == 404: await _log_download_error(resp, "404 downloading %(url)s: %(status)s; body=%(body)s") - raise Download404("{} status {}!".format(loggable_url, resp.status)) + raise Download404(f"{loggable_url} status {resp.status}!") elif resp.status != 200: await _log_download_error(resp, "Failed to download %(url)s: %(status)s; body=%(body)s") - raise DownloadError("{} status {} is not 200!".format(loggable_url, resp.status)) + raise DownloadError(f"{loggable_url} status {resp.status} is not 200!") + if expected_content_type: + actual_content_type = (resp.headers.get("Content-Type") or "").split(";", 1)[0].strip().lower() + if actual_content_type == "text/html" and "html" not in expected_content_type.lower(): + await _log_download_error(resp, "HTML response for %(url)s (expected non-HTML): %(status)s; body=%(body)s") + raise DownloadError(f"{loggable_url}: expected Content-Type {expected_content_type!r} but got HTML; treating as an error page") makedirs(parent_dir) with open(abs_filename, "wb") as fd: while True: @@ -729,6 +748,11 @@ def get_parts_of_url_path(url): async def load_json_or_yaml_from_url(context: Context, url: str, path: str, overwrite: bool = True, auth: Optional[str] = None) -> Dict[str, Any]: """Retry a json/yaml file download, load it, then return its data. + Download and parse are combined into a single retry unit: if parsing the + downloaded file fails (e.g. truncated body, an HTML error page, a Cloud + Storage transcoding glitch), the cached file is deleted and the download + is retried. + Args: context (scriptworker.context.Context): the scriptworker context. url (str): the url to download @@ -745,15 +769,42 @@ async def load_json_or_yaml_from_url(context: Context, url: str, path: str, over """ if path.endswith("json"): file_type = "json" + expected_content_type = "application/json" else: file_type = "yaml" + expected_content_type = "application/yaml" - kwargs = {} + download_kwargs = {"expected_content_type": expected_content_type} if auth: - kwargs = {"auth": auth} - if not overwrite or not os.path.exists(path): - await retry_async(download_file, args=(context, url, path), kwargs=kwargs, retry_exceptions=(DownloadError, aiohttp.ClientError, asyncio.TimeoutError)) - return load_json_or_yaml(path, is_path=True, file_type=file_type) + download_kwargs["auth"] = auth + loggable_url = get_loggable_url(url) + + async def _download_and_parse(): + # Pre-existing cache semantics (despite the misleading parameter + # name): ``overwrite=True`` uses an existing file when present; + # ``overwrite=False`` always (re)downloads. + if not overwrite or not os.path.exists(path): + await download_file(context, url, path, **download_kwargs) + try: + return load_json_or_yaml(path, is_path=True, file_type=file_type) + except ScriptWorkerTaskException as exc: + log.warning( + "Failed to parse %s from %s (cached at %s); invalidating cache and retrying: %s", + file_type, + loggable_url, + path, + exc, + ) + try: + os.remove(path) + except OSError: + pass + raise DownloadError(f"parse failure for {loggable_url}: {exc}") + + return await retry_async( + _download_and_parse, + retry_exceptions=(DownloadError, aiohttp.ClientError, asyncio.TimeoutError), + ) # match_url_path_callback {{{1 diff --git a/tests/test_utils.py b/tests/test_utils.py index 0fe45dd0..4c4c4ec9 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -443,6 +443,47 @@ async def test_download_file_404(rw_context, fake_session_404, tmpdir, auth): await utils.download_file(rw_context, "url", path, session=fake_session_404, auth=auth) +@pytest.mark.asyncio +async def test_download_file_rejects_html_when_non_html_expected(rw_context, fake_session, tmpdir): + """An HTML response raises DownloadError when a non-HTML content-type was expected.""" + + async def html_request(method, url, *args, **kwargs): + resp = FakeResponse(method, url, status=200) + resp._headers = {"Content-Type": "text/html; charset=utf-8"} + return resp + + fake_session._request = html_request + path = os.path.join(tmpdir, "foo.json") + with pytest.raises(DownloadError, match="HTML"): + await utils.download_file(rw_context, "url", path, session=fake_session, expected_content_type="application/json") + assert not os.path.exists(path) + + +@pytest.mark.asyncio +async def test_load_json_or_yaml_from_url_retries_on_parse_failure(rw_context, mocker, tmpdir): + """A parse failure invalidates the cache and triggers a re-download.""" + path = os.path.join(tmpdir, "out.json") + call_count = {"n": 0} + + async def flaky_download(rw_context, url, abs_filename, session=None, chunk_size=128, auth=None, expected_content_type=None): + call_count["n"] += 1 + if call_count["n"] == 1: + # First attempt: write garbage JSON + with open(abs_filename, "w") as fh: + fh.write("not valid json {") + else: + with open(abs_filename, "w") as fh: + fh.write('{"ok": true}') + + # Neutralize retry_async backoff so the test is fast + mocker.patch.object(utils, "calculate_sleep_time", return_value=0) + mocker.patch.object(utils, "download_file", new=flaky_download) + + result = await utils.load_json_or_yaml_from_url(rw_context, "url", path) + assert result == {"ok": True} + assert call_count["n"] == 2 + + # format_json {{{1 def test_format_json(): expected = "\n".join(["{", ' "a": 1,', ' "b": [', " 4,", " 3,", " 2", " ],", ' "c": {', ' "d": 5', " }", "}"]) @@ -474,7 +515,7 @@ def test_load_json_or_yaml(string, is_path, exception, raises, result): async def test_load_json_or_yaml_from_url(rw_context, mocker, overwrite, file_type, tmpdir): called_with_auth = [] - async def mocked_download_file(rw_context, url, abs_filename, session=None, chunk_size=128, auth=None): + async def mocked_download_file(rw_context, url, abs_filename, session=None, chunk_size=128, auth=None, expected_content_type=None): called_with_auth.append(auth == "someAuth") return @@ -495,7 +536,7 @@ async def mocked_download_file(rw_context, url, abs_filename, session=None, chun async def test_load_json_or_yaml_from_url_auth(rw_context, mocker, overwrite, file_type, tmpdir): called_with_auth = [] - async def mocked_download_file(rw_context, url, abs_filename, session=None, chunk_size=128, auth=None): + async def mocked_download_file(rw_context, url, abs_filename, session=None, chunk_size=128, auth=None, expected_content_type=None): called_with_auth.append(auth == "someAuth") return