diff --git a/snowflake_sql_api/testing.py b/snowflake_sql_api/testing.py index 381316c..bacd7bf 100644 --- a/snowflake_sql_api/testing.py +++ b/snowflake_sql_api/testing.py @@ -599,7 +599,7 @@ def make_client( from .client import SnowflakeClient _reject_managed_kwargs("make_client", kwargs) - return SnowflakeClient( + client = SnowflakeClient( account, user, private_key=_throwaway_key(), @@ -607,6 +607,8 @@ def make_client( poll_interval=poll_interval, **kwargs, ) + client._transport._owns_client = True + return client def make_async_client( @@ -621,7 +623,7 @@ def make_async_client( from .aclient import AsyncSnowflakeClient _reject_managed_kwargs("make_async_client", kwargs) - return AsyncSnowflakeClient( + client = AsyncSnowflakeClient( account, user, private_key=_throwaway_key(), @@ -629,6 +631,8 @@ def make_async_client( poll_interval=poll_interval, **kwargs, ) + client._transport._owns_client = True + return client # --------------------------------------------------------------------------- diff --git a/tests/conftest.py b/tests/conftest.py index 3b16a20..506f16b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -25,6 +25,12 @@ PASSPHRASE = b"correct horse battery staple" +def pytest_configure(config: pytest.Config) -> None: + """Load shipped testing fixtures when pytest11 entry points are unavailable.""" + if not config.pluginmanager.hasplugin("snowflake_sql_api"): + config.pluginmanager.import_plugin("snowflake_sql_api.testing") + + @pytest.fixture def fake_account() -> str: """A region-suffixed account locator for auth-normalization tests.""" diff --git a/tests/test_testing.py b/tests/test_testing.py index aa7f047..c15aed0 100644 --- a/tests/test_testing.py +++ b/tests/test_testing.py @@ -46,6 +46,7 @@ def test_make_client_returns_real_client() -> None: assert isinstance(client, SnowflakeClient) assert client.query_scalar("SELECT 1") == 1 client.close() + assert client._transport._client.is_closed def test_make_client_rejects_managed_kwargs() -> None: @@ -304,6 +305,7 @@ async def test_async_query() -> None: assert isinstance(client, AsyncSnowflakeClient) assert await client.query("SELECT id FROM t") == [{"ID": 1}, {"ID": 2}] await client.aclose() + assert client._transport._client.is_closed async def test_async_submit_and_result() -> None: