From 7af6317b2306ae882606ce2c3a68b4afe3618af5 Mon Sep 17 00:00:00 2001 From: laughingman7743 Date: Mon, 8 Jun 2026 23:25:33 +0900 Subject: [PATCH 1/3] Add on_poll callback for real-time query monitoring (#723) Add an optional on_poll callback to BaseCursor/Connection, invoked once per poll iteration with the current execution object. Wired at the cursor/connection level so it propagates via the existing **kwargs chain (no execute() changes; avoids adding to the execute() duplication tracked in #691). Covers all poll loops: BaseCursor.__poll (sync + thread-pool async), SparkBaseCursor.__poll, and the native-async aio loops. For Spark the callback receives the per-poll calculation status, so the signature is Callable[[AthenaQueryExecution | AthenaCalculationExecutionStatus], None]. Co-Authored-By: Claude Opus 4.8 --- pyathena/aio/common.py | 2 + pyathena/aio/spark/cursor.py | 2 + pyathena/common.py | 13 +++++ pyathena/connection.py | 11 +++- pyathena/spark/common.py | 2 + tests/pyathena/spark/test_spark_cursor.py | 28 ++++++++++ tests/pyathena/test_async_cursor.py | 14 +++++ tests/pyathena/test_cursor.py | 63 +++++++++++++++++++++++ 8 files changed, 134 insertions(+), 1 deletion(-) diff --git a/pyathena/aio/common.py b/pyathena/aio/common.py index ec593297..b24eb596 100644 --- a/pyathena/aio/common.py +++ b/pyathena/aio/common.py @@ -83,6 +83,8 @@ async def _get_query_execution(self, query_id: str) -> AthenaQueryExecution: # async def __poll(self, query_id: str) -> AthenaQueryExecution: while True: query_execution = await self._get_query_execution(query_id) + if self._on_poll: + self._on_poll(query_execution) if query_execution.state in [ AthenaQueryExecution.STATE_SUCCEEDED, AthenaQueryExecution.STATE_FAILED, diff --git a/pyathena/aio/spark/cursor.py b/pyathena/aio/spark/cursor.py index c414eba1..68f537f7 100644 --- a/pyathena/aio/spark/cursor.py +++ b/pyathena/aio/spark/cursor.py @@ -128,6 +128,8 @@ async def _calculate( # type: ignore[override] async def __poll(self, query_id: str) -> AthenaQueryExecution | AthenaCalculationExecution: while True: calculation_status = await self._get_calculation_execution_status(query_id) + if self._on_poll: + self._on_poll(calculation_status) if calculation_status.state in [ AthenaCalculationExecutionStatus.STATE_COMPLETED, AthenaCalculationExecutionStatus.STATE_FAILED, diff --git a/pyathena/common.py b/pyathena/common.py index 66ac4fba..db570eeb 100644 --- a/pyathena/common.py +++ b/pyathena/common.py @@ -4,6 +4,7 @@ import sys import time from abc import ABCMeta, abstractmethod +from collections.abc import Callable from datetime import datetime, timedelta, timezone from typing import TYPE_CHECKING, Any, cast @@ -27,6 +28,14 @@ _logger = logging.getLogger(__name__) +OnPollCallback = Callable[[AthenaQueryExecution | AthenaCalculationExecutionStatus], None] +"""Type of the optional ``on_poll`` callback. + +Invoked once per poll iteration with the current execution object: an +:class:`~pyathena.model.AthenaQueryExecution` for SQL queries, or an +:class:`~pyathena.model.AthenaCalculationExecutionStatus` for Spark calculations. +""" + class CursorIterator(metaclass=ABCMeta): """Abstract base class providing iteration and result fetching capabilities for cursors. @@ -164,6 +173,7 @@ def __init__( kill_on_interrupt: bool, result_reuse_enable: bool, result_reuse_minutes: int, + on_poll: OnPollCallback | None = None, **kwargs, ) -> None: super().__init__() @@ -181,6 +191,7 @@ def __init__( self._kill_on_interrupt = kill_on_interrupt self._result_reuse_enable = result_reuse_enable self._result_reuse_minutes = result_reuse_minutes + self._on_poll = on_poll @staticmethod def get_default_converter(unload: bool = False) -> DefaultTypeConverter | Any: @@ -558,6 +569,8 @@ def _list_query_executions( def __poll(self, query_id: str) -> AthenaQueryExecution | AthenaCalculationExecution: while True: query_execution = self._get_query_execution(query_id) + if self._on_poll: + self._on_poll(query_execution) if query_execution.state in [ AthenaQueryExecution.STATE_SUCCEEDED, AthenaQueryExecution.STATE_FAILED, diff --git a/pyathena/connection.py b/pyathena/connection.py index 65a927f5..ae9c3717 100644 --- a/pyathena/connection.py +++ b/pyathena/connection.py @@ -18,7 +18,7 @@ from botocore.config import Config import pyathena -from pyathena.common import BaseCursor, CursorIterator +from pyathena.common import BaseCursor, CursorIterator, OnPollCallback from pyathena.converter import Converter from pyathena.cursor import Cursor from pyathena.error import NotSupportedError, ProgrammingError @@ -127,6 +127,7 @@ def __init__( result_reuse_enable: bool = ..., result_reuse_minutes: int = ..., on_start_query_execution: Callable[[str], None] | None = ..., + on_poll: OnPollCallback | None = ..., **kwargs, ) -> None: ... @@ -158,6 +159,7 @@ def __init__( result_reuse_enable: bool = ..., result_reuse_minutes: int = ..., on_start_query_execution: Callable[[str], None] | None = ..., + on_poll: OnPollCallback | None = ..., **kwargs, ) -> None: ... @@ -188,6 +190,7 @@ def __init__( result_reuse_enable: bool = False, result_reuse_minutes: int = CursorIterator.DEFAULT_RESULT_REUSE_MINUTES, on_start_query_execution: Callable[[str], None] | None = None, + on_poll: OnPollCallback | None = None, **kwargs, ) -> None: """Initialize a new Athena database connection. @@ -224,6 +227,10 @@ def __init__( result_reuse_enable: Enable Athena query result reuse. Defaults to False. result_reuse_minutes: Minutes to reuse cached results. on_start_query_execution: Callback function called when query starts. + on_poll: Callback invoked once per poll iteration with the current + execution object (``AthenaQueryExecution``, or + ``AthenaCalculationExecutionStatus`` for Spark). Useful for + monitoring live query progress. Defaults to None. **kwargs: Additional arguments passed to boto3 Session and client. Raises: @@ -331,6 +338,7 @@ def __init__( self.result_reuse_enable = result_reuse_enable self.result_reuse_minutes = result_reuse_minutes self.on_start_query_execution = on_start_query_execution + self.on_poll = on_poll def _assume_role( self, @@ -555,6 +563,7 @@ def cursor( on_start_query_execution=kwargs.pop( "on_start_query_execution", self.on_start_query_execution ), + on_poll=kwargs.pop("on_poll", self.on_poll), **kwargs, ) diff --git a/pyathena/spark/common.py b/pyathena/spark/common.py index 959ad014..836ab56a 100644 --- a/pyathena/spark/common.py +++ b/pyathena/spark/common.py @@ -201,6 +201,8 @@ def _terminate_session(self) -> None: def __poll(self, query_id: str) -> AthenaQueryExecution | AthenaCalculationExecution: while True: calculation_status = self._get_calculation_execution_status(query_id) + if self._on_poll: + self._on_poll(calculation_status) if calculation_status.state in [ AthenaCalculationExecutionStatus.STATE_COMPLETED, AthenaCalculationExecutionStatus.STATE_FAILED, diff --git a/tests/pyathena/spark/test_spark_cursor.py b/tests/pyathena/spark/test_spark_cursor.py index 12723ff1..a1da9d51 100644 --- a/tests/pyathena/spark/test_spark_cursor.py +++ b/tests/pyathena/spark/test_spark_cursor.py @@ -2,11 +2,13 @@ import time from concurrent.futures import ThreadPoolExecutor from random import randint +from unittest.mock import MagicMock, patch import pytest from pyathena import DatabaseError, OperationalError from pyathena.model import AthenaCalculationExecutionStatus +from pyathena.spark.cursor import SparkCursor from tests import ENV @@ -145,3 +147,29 @@ def cancel(c): ) ), ) + + +def test_spark_on_poll_invoked_each_iteration(): + """on_poll fires once per Spark calculation poll iteration with the status (no AWS).""" + states = [ + AthenaCalculationExecutionStatus.STATE_CREATING, + AthenaCalculationExecutionStatus.STATE_RUNNING, + AthenaCalculationExecutionStatus.STATE_COMPLETED, + ] + statuses = [MagicMock(state=state) for state in states] + final_execution = MagicMock() + received = [] + + cursor = SparkCursor.__new__(SparkCursor) # bypass __init__ to avoid AWS calls + cursor._poll_interval = 0 + cursor._kill_on_interrupt = False + cursor._on_poll = received.append + + with ( + patch.object(SparkCursor, "_get_calculation_execution_status", side_effect=statuses), + patch.object(SparkCursor, "_get_calculation_execution", return_value=final_execution), + ): + result = cursor._poll("calculation_id") + + assert [status.state for status in received] == states + assert result is final_execution diff --git a/tests/pyathena/test_async_cursor.py b/tests/pyathena/test_async_cursor.py index f04d15f6..1cf4d0d9 100644 --- a/tests/pyathena/test_async_cursor.py +++ b/tests/pyathena/test_async_cursor.py @@ -172,6 +172,20 @@ def test_poll(self, async_cursor): AthenaQueryExecution.STATE_CANCELLED, ] + def test_on_poll(self): + """on_poll fires during async polling (issue #723 example).""" + states = [] + + with contextlib.closing( + connect(on_poll=lambda execution: states.append(execution.state)) + ) as conn: + cursor = conn.cursor(AsyncCursor) + _, future = cursor.execute("SELECT 1") + future.result() + + assert len(states) >= 1 + assert states[-1] == AthenaQueryExecution.STATE_SUCCEEDED + def test_bad_query(self, async_cursor): query_id, future = async_cursor.execute( "SELECT does_not_exist FROM this_really_does_not_exist" diff --git a/tests/pyathena/test_cursor.py b/tests/pyathena/test_cursor.py index 3dd15fc2..15155180 100644 --- a/tests/pyathena/test_cursor.py +++ b/tests/pyathena/test_cursor.py @@ -11,6 +11,7 @@ from datetime import date, datetime, timezone from decimal import Decimal from random import randint +from unittest.mock import MagicMock, patch import pytest @@ -934,6 +935,68 @@ def verify_query_id(): row = cursor.fetchone() assert row == (5,) + def test_on_poll_invoked_each_iteration(self): + """on_poll fires once per poll iteration with the current execution (no AWS).""" + states = [ + AthenaQueryExecution.STATE_QUEUED, + AthenaQueryExecution.STATE_RUNNING, + AthenaQueryExecution.STATE_SUCCEEDED, + ] + executions = [MagicMock(state=state) for state in states] + received = [] + + cursor = Cursor.__new__(Cursor) # bypass __init__ to avoid AWS calls + cursor._poll_interval = 0 + cursor._kill_on_interrupt = False + cursor._on_poll = received.append + + with patch.object(Cursor, "_get_query_execution", side_effect=executions): + result = cursor._poll("query_id") + + # Callback received every state in order, including the terminal one + assert [execution.state for execution in received] == states + assert result is executions[-1] + + def test_on_poll_none_is_noop(self): + """A None on_poll callback does not affect polling (no AWS).""" + execution = MagicMock(state=AthenaQueryExecution.STATE_SUCCEEDED) + + cursor = Cursor.__new__(Cursor) # bypass __init__ to avoid AWS calls + cursor._poll_interval = 0 + cursor._kill_on_interrupt = False + cursor._on_poll = None + + with patch.object(Cursor, "_get_query_execution", return_value=execution): + result = cursor._poll("query_id") + + assert result is execution + + def test_on_poll_connection_level(self): + """Connection-level on_poll fires during query execution.""" + states = [] + + with contextlib.closing( + connect(on_poll=lambda execution: states.append(execution.state)) + ) as conn: + cursor = conn.cursor() + cursor.execute("SELECT 1") + + assert len(states) >= 1 + assert states[-1] == AthenaQueryExecution.STATE_SUCCEEDED + assert cursor.fetchone() == (1,) + + def test_on_poll_cursor_level(self): + """on_poll passed via cursor() fires during query execution.""" + states = [] + + with contextlib.closing(connect()) as conn: + cursor = conn.cursor(on_poll=lambda execution: states.append(execution.state)) + cursor.execute("SELECT 1") + + assert len(states) >= 1 + assert states[-1] == AthenaQueryExecution.STATE_SUCCEEDED + assert cursor.fetchone() == (1,) + def test_null_vs_empty_string(self, cursor): """ Default Cursor should properly distinguish NULL from empty string. From f4b8f6477b3f78854c6db61425bb8a49dcac7bbc Mon Sep 17 00:00:00 2001 From: laughingman7743 Date: Mon, 8 Jun 2026 23:45:45 +0900 Subject: [PATCH 2/3] Document on_poll polling callback Add a "Query polling callback" section to docs/usage.md covering on_poll: connection- and cursor-level configuration, the synchronous callback contract, per-iteration invocation including the terminal state, async cursor usage, and the Spark calculation-status payload. Co-Authored-By: Claude Opus 4.8 --- docs/usage.md | 80 +++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 80 insertions(+) diff --git a/docs/usage.md b/docs/usage.md index 88d19e6b..7b2a84d4 100644 --- a/docs/usage.md +++ b/docs/usage.md @@ -427,6 +427,86 @@ The `on_start_query_execution` callback is supported by the following cursor typ Note: `AsyncCursor` and its variants do not support this callback as they already return the query ID immediately through their different execution model. +## Query polling callback + +PyAthena provides an `on_poll` callback that is invoked once per poll iteration with the +current query execution object, while PyAthena waits for the query to finish. This is useful +for rendering live query progress (state, elapsed time, data scanned) in interactive +environments such as Jupyter notebooks. + +The callback is optional (`None` by default), so there is no impact on existing behaviour or +performance when it is not used. It must be a **synchronous** function with the signature +`Callable[[AthenaQueryExecution], None]` (for Spark calculations it receives an +`AthenaCalculationExecutionStatus`). It is invoked on every poll, including the final +iteration that observes the terminal state (`SUCCEEDED`, `FAILED`, or `CANCELLED`). + +Unlike `on_start_query_execution`, `on_poll` is configured at the connection or cursor level +only (there is no execute-level override). + +### Connection-level callback + +```python +from pyathena import connect + +def on_poll(query_execution): + print( + f"State: {query_execution.state}, " + f"scanned: {query_execution.data_scanned_in_bytes} bytes" + ) + +cursor = connect( + s3_staging_dir="s3://YOUR_S3_BUCKET/path/to/", + region_name="us-west-2", + on_poll=on_poll, +).cursor() + +cursor.execute("SELECT * FROM many_rows") # on_poll is invoked on each poll +``` + +### Cursor-level callback + +```python +from pyathena import connect + +def on_poll(query_execution): + print(f"State: {query_execution.state}") + +conn = connect( + s3_staging_dir="s3://YOUR_S3_BUCKET/path/to/", + region_name="us-west-2", +) +cursor = conn.cursor(on_poll=on_poll) +cursor.execute("SELECT * FROM many_rows") +``` + +### Asynchronous cursors + +`on_poll` also works with the asynchronous cursors. Because polling runs in a background +thread (or event loop), the callback runs there too, so keep it lightweight and thread-safe: + +```python +from pyathena import connect +from pyathena.pandas.async_cursor import AsyncPandasCursor + +def on_poll(query_execution): + print(f"State: {query_execution.state}") + +conn = connect( + s3_staging_dir="s3://YOUR_S3_BUCKET/path/to/", + region_name="us-west-2", +) +cursor = conn.cursor(AsyncPandasCursor, on_poll=on_poll) +query_id, future = cursor.execute("SELECT * FROM many_rows") +result = future.result() +``` + +### Supported cursor types + +`on_poll` is supported by **all** cursor types, since polling is shared by the base cursor: +the synchronous cursors, the `Async*` cursors, the native-async `Aio*` cursors, and the Spark +cursors. For Spark cursors the callback receives the per-poll +`AthenaCalculationExecutionStatus` rather than an `AthenaQueryExecution`. + ## Type hints for complex types *New in version 3.30.0.* From 97340538fbe5edf29cf6f99babe4c3ec38453592 Mon Sep 17 00:00:00 2001 From: laughingman7743 Date: Tue, 9 Jun 2026 12:17:54 +0900 Subject: [PATCH 3/3] Centralize on_start_query_execution storage in BaseCursor Move the on_start_query_execution field from the five synchronous cursors (Cursor, PandasCursor, ArrowCursor, PolarsCursor, S3FSCursor) into BaseCursor.__init__, mirroring on_poll, so both connection-level callbacks live in one place. The per-execute() override and its invocation stay in each synchronous cursor (the broader execute() kwargs consolidation is tracked in #691). Async/aio/Spark cursors inherit the field but do not invoke it, as they return the query id immediately through their execution model. Behavior is unchanged; a clarifying comment documents this. Co-Authored-By: Claude Opus 4.8 --- pyathena/arrow/cursor.py | 3 --- pyathena/common.py | 5 +++++ pyathena/cursor.py | 2 -- pyathena/pandas/cursor.py | 3 --- pyathena/polars/cursor.py | 3 --- pyathena/s3fs/cursor.py | 3 --- 6 files changed, 5 insertions(+), 14 deletions(-) diff --git a/pyathena/arrow/cursor.py b/pyathena/arrow/cursor.py index 9d3d879f..b4775df6 100644 --- a/pyathena/arrow/cursor.py +++ b/pyathena/arrow/cursor.py @@ -63,7 +63,6 @@ def __init__( unload: bool = False, result_reuse_enable: bool = False, result_reuse_minutes: int = CursorIterator.DEFAULT_RESULT_REUSE_MINUTES, - on_start_query_execution: Callable[[str], None] | None = None, connect_timeout: float | None = None, request_timeout: float | None = None, **kwargs, @@ -82,7 +81,6 @@ def __init__( unload: Enable UNLOAD for high-performance Parquet output. result_reuse_enable: Enable Athena query result reuse. result_reuse_minutes: Minutes to reuse cached results. - on_start_query_execution: Callback invoked when query starts. connect_timeout: Socket connection timeout in seconds for S3 operations. Defaults to AWS SDK default (typically 1 second) if not specified. request_timeout: Request timeout in seconds for S3 operations. @@ -113,7 +111,6 @@ def __init__( **kwargs, ) self._unload = unload - self._on_start_query_execution = on_start_query_execution self._connect_timeout = connect_timeout self._request_timeout = request_timeout diff --git a/pyathena/common.py b/pyathena/common.py index db570eeb..04d919b4 100644 --- a/pyathena/common.py +++ b/pyathena/common.py @@ -173,6 +173,7 @@ def __init__( kill_on_interrupt: bool, result_reuse_enable: bool, result_reuse_minutes: int, + on_start_query_execution: Callable[[str], None] | None = None, on_poll: OnPollCallback | None = None, **kwargs, ) -> None: @@ -191,6 +192,10 @@ def __init__( self._kill_on_interrupt = kill_on_interrupt self._result_reuse_enable = result_reuse_enable self._result_reuse_minutes = result_reuse_minutes + # ``on_start_query_execution`` is invoked only by cursors whose ``execute()`` + # supports it (the synchronous cursors). Async/aio/Spark cursors return the + # query id immediately through their execution model and do not invoke it. + self._on_start_query_execution = on_start_query_execution self._on_poll = on_poll @staticmethod diff --git a/pyathena/cursor.py b/pyathena/cursor.py index d113b387..a9438d0a 100644 --- a/pyathena/cursor.py +++ b/pyathena/cursor.py @@ -52,7 +52,6 @@ def __init__( kill_on_interrupt: bool = True, result_reuse_enable: bool = False, result_reuse_minutes: int = CursorIterator.DEFAULT_RESULT_REUSE_MINUTES, - on_start_query_execution: Callable[[str], None] | None = None, **kwargs, ) -> None: super().__init__( @@ -69,7 +68,6 @@ def __init__( **kwargs, ) self._result_set_class = AthenaResultSet - self._on_start_query_execution = on_start_query_execution @property def arraysize(self) -> int: diff --git a/pyathena/pandas/cursor.py b/pyathena/pandas/cursor.py index 22a7d8ac..3b55e7b6 100644 --- a/pyathena/pandas/cursor.py +++ b/pyathena/pandas/cursor.py @@ -79,7 +79,6 @@ def __init__( result_reuse_enable: bool = False, result_reuse_minutes: int = CursorIterator.DEFAULT_RESULT_REUSE_MINUTES, auto_optimize_chunksize: bool = False, - on_start_query_execution: Callable[[str], None] | None = None, **kwargs, ) -> None: """Initialize PandasCursor with configuration options. @@ -105,7 +104,6 @@ def __init__( auto_optimize_chunksize: Enable automatic chunksize determination for large files. Only effective when chunksize is None. Default: False (no automatic chunking). - on_start_query_execution: Callback for query start events. **kwargs: Additional arguments passed to pandas.read_csv. """ super().__init__( @@ -128,7 +126,6 @@ def __init__( self._cache_type = cache_type self._max_workers = max_workers self._auto_optimize_chunksize = auto_optimize_chunksize - self._on_start_query_execution = on_start_query_execution @staticmethod def get_default_converter( diff --git a/pyathena/polars/cursor.py b/pyathena/polars/cursor.py index efc738c4..12ab6b25 100644 --- a/pyathena/polars/cursor.py +++ b/pyathena/polars/cursor.py @@ -74,7 +74,6 @@ def __init__( unload: bool = False, result_reuse_enable: bool = False, result_reuse_minutes: int = CursorIterator.DEFAULT_RESULT_REUSE_MINUTES, - on_start_query_execution: Callable[[str], None] | None = None, block_size: int | None = None, cache_type: str | None = None, max_workers: int = (cpu_count() or 1) * 5, @@ -95,7 +94,6 @@ def __init__( unload: Enable UNLOAD for high-performance Parquet output. result_reuse_enable: Enable Athena query result reuse. result_reuse_minutes: Minutes to reuse cached results. - on_start_query_execution: Callback invoked when query starts. block_size: S3 read block size. cache_type: S3 caching strategy. max_workers: Maximum worker threads for parallel S3 operations. @@ -123,7 +121,6 @@ def __init__( **kwargs, ) self._unload = unload - self._on_start_query_execution = on_start_query_execution self._block_size = block_size self._cache_type = cache_type self._max_workers = max_workers diff --git a/pyathena/s3fs/cursor.py b/pyathena/s3fs/cursor.py index dfc5dd5e..f0521f7f 100644 --- a/pyathena/s3fs/cursor.py +++ b/pyathena/s3fs/cursor.py @@ -58,7 +58,6 @@ def __init__( kill_on_interrupt: bool = True, result_reuse_enable: bool = False, result_reuse_minutes: int = CursorIterator.DEFAULT_RESULT_REUSE_MINUTES, - on_start_query_execution: Callable[[str], None] | None = None, csv_reader: CSVReaderType | None = None, **kwargs, ) -> None: @@ -75,7 +74,6 @@ def __init__( kill_on_interrupt: Cancel running query on keyboard interrupt. result_reuse_enable: Enable Athena query result reuse. result_reuse_minutes: Minutes to reuse cached results. - on_start_query_execution: Callback invoked when query starts. csv_reader: CSV reader class to use for parsing results. Use AthenaCSVReader (default) to distinguish between NULL (unquoted empty) and empty string (quoted empty ""). @@ -104,7 +102,6 @@ def __init__( result_reuse_minutes=result_reuse_minutes, **kwargs, ) - self._on_start_query_execution = on_start_query_execution self._csv_reader = csv_reader @staticmethod