diff --git a/docs/news.rst b/docs/news.rst index 8b70a2ae4..ae5b86ac4 100644 --- a/docs/news.rst +++ b/docs/news.rst @@ -11,3 +11,6 @@ gaussdb.0b1 ^^^^^^^^^^^^^ - First public release on PyPI. +- Fixed a crash on ARM64 with some libpq builds when connection attempts fail: + failed-connection diagnostics now use a safe snapshot instead of reading + every libpq connection attribute. diff --git a/gaussdb/gaussdb/errors.py b/gaussdb/gaussdb/errors.py index c05c0c5c2..118e0cd21 100644 --- a/gaussdb/gaussdb/errors.py +++ b/gaussdb/gaussdb/errors.py @@ -22,7 +22,7 @@ from typing import TYPE_CHECKING, Any, Callable, NoReturn from asyncio import CancelledError -from dataclasses import dataclass, field, fields +from dataclasses import dataclass, field from collections.abc import Sequence from .pq.abc import PGconn, PGresult @@ -229,13 +229,42 @@ def send_flush_request(self) -> NoReturn: self._raise() -def finish_pgconn(pgconn: PGconn) -> PGconn: - args = {} - for f in fields(FinishedPGconn): - try: - args[f.name] = getattr(pgconn, f.name) - except Exception: - pass +def finish_pgconn( + pgconn: PGconn, + *, + db: bytes | str | int | None = None, + user: bytes | str | int | None = None, + host: bytes | str | int | None = None, + hostaddr: bytes | str | int | None = None, + port: bytes | str | int | None = None, + options: bytes | str | int | None = None, + error_message: bytes | str | int | None = None, + needs_password: bool | None = None, +) -> PGconn: + def _tobytes(value: bytes | str | int | None) -> bytes | None: + if value is None: + return None + if isinstance(value, bytes): + return value + return str(value).encode("utf-8", "replace") + + args: dict[str, Any] = {} + for name, value in ( + ("db", db), + ("user", user), + ("host", host), + ("hostaddr", hostaddr), + ("port", port), + ("options", options), + ("error_message", error_message), + ): + bvalue = _tobytes(value) + if bvalue is not None: + args[name] = bvalue + + if needs_password is not None: + args["needs_password"] = bool(needs_password) + pgconn.finish() return FinishedPGconn(**args) diff --git a/gaussdb/gaussdb/generators.py b/gaussdb/gaussdb/generators.py index 21d36319f..4ccb9793c 100644 --- a/gaussdb/gaussdb/generators.py +++ b/gaussdb/gaussdb/generators.py @@ -32,6 +32,7 @@ from .pq.abc import PGcancelConn, PGconn, PGresult from .waiting import Ready, Wait from ._cmodule import _gaussdb +from .conninfo import conninfo_to_dict from ._encodings import conninfo_encoding OK = pq.ConnStatus.OK @@ -69,9 +70,24 @@ def _connect(conninfo: str, *, timeout: float = 0.0) -> PQGenConn[PGconn]: while True: if conn.status == BAD: encoding = conninfo_encoding(conninfo) + message = conn.get_error_message(encoding) + try: + conn_params = conninfo_to_dict(conninfo) + except Exception: + conn_params = {} raise e.OperationalError( - f"connection is bad: {conn.get_error_message(encoding)}", - pgconn=conn, + f"connection is bad: {message}", + pgconn=e.finish_pgconn( + conn, + db=conn_params.get("dbname"), + user=conn_params.get("user"), + host=conn_params.get("host"), + hostaddr=conn_params.get("hostaddr"), + port=conn_params.get("port"), + options=conn_params.get("options"), + error_message=message, + needs_password="password" in message.lower(), + ), ) status = conn.connect_poll() @@ -89,13 +105,41 @@ def _connect(conninfo: str, *, timeout: float = 0.0) -> PQGenConn[PGconn]: break elif status == POLL_FAILED: encoding = conninfo_encoding(conninfo) + message = conn.get_error_message(encoding) + try: + conn_params = conninfo_to_dict(conninfo) + except Exception: + conn_params = {} raise e.OperationalError( - f"connection failed: {conn.get_error_message(encoding)}", - pgconn=e.finish_pgconn(conn), + f"connection failed: {message}", + pgconn=e.finish_pgconn( + conn, + db=conn_params.get("dbname"), + user=conn_params.get("user"), + host=conn_params.get("host"), + hostaddr=conn_params.get("hostaddr"), + port=conn_params.get("port"), + options=conn_params.get("options"), + error_message=message, + needs_password="password" in message.lower(), + ), ) else: + try: + conn_params = conninfo_to_dict(conninfo) + except Exception: + conn_params = {} raise e.InternalError( - f"unexpected poll status: {status}", pgconn=e.finish_pgconn(conn) + f"unexpected poll status: {status}", + pgconn=e.finish_pgconn( + conn, + db=conn_params.get("dbname"), + user=conn_params.get("user"), + host=conn_params.get("host"), + hostaddr=conn_params.get("hostaddr"), + port=conn_params.get("port"), + options=conn_params.get("options"), + ), ) conn.nonblocking = 1 diff --git a/tests/test_errors.py b/tests/test_errors.py index 7f5ac57a0..3a7cc69fb 100644 --- a/tests/test_errors.py +++ b/tests/test_errors.py @@ -3,6 +3,7 @@ import re import sys import pickle +from typing import cast from weakref import ref import pytest @@ -27,6 +28,36 @@ def test_finishedpgconn(pgconn): pgconn.socket +def test_finish_pgconn_doesnt_read_pgconn_attrs(): + class DummyPGconn: + def __init__(self): + self.finished = False + + def finish(self): + self.finished = True + + @property + def db(self): + raise AssertionError("db should not be read") + + @property + def error_message(self): + raise AssertionError("error_message should not be read") + + dummy = DummyPGconn() + finished = e.finish_pgconn( + cast(pq.abc.PGconn, dummy), + db="nosuchdb", + error_message="failed", + needs_password=True, + ) + + assert dummy.finished + assert finished.db == b"nosuchdb" + assert finished.error_message == b"failed" + assert finished.needs_password + + @pytest.mark.crdb_skip("severity_nonlocalized") def test_error_diag(conn): cur = conn.cursor() diff --git a/tests/test_generators.py b/tests/test_generators.py index 89ee9e134..a0d6555bd 100644 --- a/tests/test_generators.py +++ b/tests/test_generators.py @@ -35,7 +35,7 @@ def test_connect_operationalerror_pgconn(generators, dsn, monkeypatch): pgconn = excinfo.value.pgconn assert pgconn is not None assert pgconn.needs_password - assert b"ERROR: Invalid username/password,login denied.\n" in pgconn.error_message + assert b"Invalid username/password,login denied." in pgconn.error_message assert pgconn.status == pq.ConnStatus.BAD.value assert pgconn.transaction_status == pq.TransactionStatus.UNKNOWN.value assert pgconn.pipeline_status == pq.PipelineStatus.OFF.value