Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 80 additions & 0 deletions docs/usage.md
Original file line number Diff line number Diff line change
Expand Up @@ -427,6 +427,86 @@ The `on_start_query_execution` callback is supported by the following cursor typ
Note: `AsyncCursor` and its variants do not support this callback as they already
return the query ID immediately through their different execution model.

## Query polling callback

PyAthena provides an `on_poll` callback that is invoked once per poll iteration with the
current query execution object, while PyAthena waits for the query to finish. This is useful
for rendering live query progress (state, elapsed time, data scanned) in interactive
environments such as Jupyter notebooks.

The callback is optional (`None` by default), so there is no impact on existing behaviour or
performance when it is not used. It must be a **synchronous** function with the signature
`Callable[[AthenaQueryExecution], None]` (for Spark calculations it receives an
`AthenaCalculationExecutionStatus`). It is invoked on every poll, including the final
iteration that observes the terminal state (`SUCCEEDED`, `FAILED`, or `CANCELLED` for SQL
queries; Spark calculations report `COMPLETED` rather than `SUCCEEDED`).

Unlike `on_start_query_execution`, `on_poll` is configured at the connection or cursor level
only (there is no execute-level override).

### Connection-level callback

```python
from pyathena import connect

def on_poll(query_execution):
    print(
        f"State: {query_execution.state}, "
        f"scanned: {query_execution.data_scanned_in_bytes} bytes"
    )

cursor = connect(
s3_staging_dir="s3://YOUR_S3_BUCKET/path/to/",
region_name="us-west-2",
on_poll=on_poll,
).cursor()

cursor.execute("SELECT * FROM many_rows") # on_poll is invoked on each poll
```

### Cursor-level callback

```python
from pyathena import connect

def on_poll(query_execution):
    print(f"State: {query_execution.state}")

conn = connect(
s3_staging_dir="s3://YOUR_S3_BUCKET/path/to/",
region_name="us-west-2",
)
cursor = conn.cursor(on_poll=on_poll)
cursor.execute("SELECT * FROM many_rows")
```

### Asynchronous cursors

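`on_poll` also works with the asynchronous cursors. Because polling runs in a background
thread (or, for the native-async `Aio*` cursors, on the event loop), the callback runs there
too. It is always called synchronously and is never awaited, so a slow callback delays the
next poll and, on the event loop, blocks other tasks. Keep it lightweight and thread-safe: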

```python
from pyathena import connect
from pyathena.pandas.async_cursor import AsyncPandasCursor

def on_poll(query_execution):
    print(f"State: {query_execution.state}")

conn = connect(
s3_staging_dir="s3://YOUR_S3_BUCKET/path/to/",
region_name="us-west-2",
)
cursor = conn.cursor(AsyncPandasCursor, on_poll=on_poll)
query_id, future = cursor.execute("SELECT * FROM many_rows")
result = future.result()
```

### Supported cursor types

`on_poll` is supported by **all** cursor types, since polling is shared by the base cursor:
the synchronous cursors, the `Async*` cursors, the native-async `Aio*` cursors, and the Spark
cursors. For Spark cursors the callback receives the per-poll
`AthenaCalculationExecutionStatus` rather than an `AthenaQueryExecution`.

## Type hints for complex types

*New in version 3.30.0.*
Expand Down
2 changes: 2 additions & 0 deletions pyathena/aio/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,8 @@ async def _get_query_execution(self, query_id: str) -> AthenaQueryExecution: #
async def __poll(self, query_id: str) -> AthenaQueryExecution:
while True:
query_execution = await self._get_query_execution(query_id)
if self._on_poll:
self._on_poll(query_execution)
if query_execution.state in [
AthenaQueryExecution.STATE_SUCCEEDED,
AthenaQueryExecution.STATE_FAILED,
Expand Down
2 changes: 2 additions & 0 deletions pyathena/aio/spark/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,8 @@ async def _calculate( # type: ignore[override]
async def __poll(self, query_id: str) -> AthenaQueryExecution | AthenaCalculationExecution:
while True:
calculation_status = await self._get_calculation_execution_status(query_id)
if self._on_poll:
self._on_poll(calculation_status)
if calculation_status.state in [
AthenaCalculationExecutionStatus.STATE_COMPLETED,
AthenaCalculationExecutionStatus.STATE_FAILED,
Expand Down
3 changes: 0 additions & 3 deletions pyathena/arrow/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@ def __init__(
unload: bool = False,
result_reuse_enable: bool = False,
result_reuse_minutes: int = CursorIterator.DEFAULT_RESULT_REUSE_MINUTES,
on_start_query_execution: Callable[[str], None] | None = None,
connect_timeout: float | None = None,
request_timeout: float | None = None,
**kwargs,
Expand All @@ -82,7 +81,6 @@ def __init__(
unload: Enable UNLOAD for high-performance Parquet output.
result_reuse_enable: Enable Athena query result reuse.
result_reuse_minutes: Minutes to reuse cached results.
on_start_query_execution: Callback invoked when query starts.
connect_timeout: Socket connection timeout in seconds for S3 operations.
Defaults to AWS SDK default (typically 1 second) if not specified.
request_timeout: Request timeout in seconds for S3 operations.
Expand Down Expand Up @@ -113,7 +111,6 @@ def __init__(
**kwargs,
)
self._unload = unload
self._on_start_query_execution = on_start_query_execution
self._connect_timeout = connect_timeout
self._request_timeout = request_timeout

Expand Down
18 changes: 18 additions & 0 deletions pyathena/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import sys
import time
from abc import ABCMeta, abstractmethod
from collections.abc import Callable
from datetime import datetime, timedelta, timezone
from typing import TYPE_CHECKING, Any, cast

Expand All @@ -27,6 +28,14 @@

_logger = logging.getLogger(__name__)

# Shared alias used to annotate ``on_poll`` on both the connection and the base cursor.
# The polling loops call the callback directly (it is never awaited), so it must be a
# plain synchronous callable.
OnPollCallback = Callable[[AthenaQueryExecution | AthenaCalculationExecutionStatus], None]
"""Type of the optional ``on_poll`` callback.

Invoked once per poll iteration with the current execution object: an
:class:`~pyathena.model.AthenaQueryExecution` for SQL queries, or an
:class:`~pyathena.model.AthenaCalculationExecutionStatus` for Spark calculations.
"""


class CursorIterator(metaclass=ABCMeta):
"""Abstract base class providing iteration and result fetching capabilities for cursors.
Expand Down Expand Up @@ -164,6 +173,8 @@ def __init__(
kill_on_interrupt: bool,
result_reuse_enable: bool,
result_reuse_minutes: int,
on_start_query_execution: Callable[[str], None] | None = None,
on_poll: OnPollCallback | None = None,
**kwargs,
) -> None:
super().__init__()
Expand All @@ -181,6 +192,11 @@ def __init__(
self._kill_on_interrupt = kill_on_interrupt
self._result_reuse_enable = result_reuse_enable
self._result_reuse_minutes = result_reuse_minutes
# ``on_start_query_execution`` is invoked only by cursors whose ``execute()``
# supports it (the synchronous cursors). Async/aio/Spark cursors return the
# query id immediately through their execution model and do not invoke it.
self._on_start_query_execution = on_start_query_execution
self._on_poll = on_poll

@staticmethod
def get_default_converter(unload: bool = False) -> DefaultTypeConverter | Any:
Expand Down Expand Up @@ -558,6 +574,8 @@ def _list_query_executions(
def __poll(self, query_id: str) -> AthenaQueryExecution | AthenaCalculationExecution:
while True:
query_execution = self._get_query_execution(query_id)
if self._on_poll:
self._on_poll(query_execution)
if query_execution.state in [
AthenaQueryExecution.STATE_SUCCEEDED,
AthenaQueryExecution.STATE_FAILED,
Expand Down
11 changes: 10 additions & 1 deletion pyathena/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from botocore.config import Config

import pyathena
from pyathena.common import BaseCursor, CursorIterator
from pyathena.common import BaseCursor, CursorIterator, OnPollCallback
from pyathena.converter import Converter
from pyathena.cursor import Cursor
from pyathena.error import NotSupportedError, ProgrammingError
Expand Down Expand Up @@ -127,6 +127,7 @@ def __init__(
result_reuse_enable: bool = ...,
result_reuse_minutes: int = ...,
on_start_query_execution: Callable[[str], None] | None = ...,
on_poll: OnPollCallback | None = ...,
**kwargs,
) -> None: ...

Expand Down Expand Up @@ -158,6 +159,7 @@ def __init__(
result_reuse_enable: bool = ...,
result_reuse_minutes: int = ...,
on_start_query_execution: Callable[[str], None] | None = ...,
on_poll: OnPollCallback | None = ...,
**kwargs,
) -> None: ...

Expand Down Expand Up @@ -188,6 +190,7 @@ def __init__(
result_reuse_enable: bool = False,
result_reuse_minutes: int = CursorIterator.DEFAULT_RESULT_REUSE_MINUTES,
on_start_query_execution: Callable[[str], None] | None = None,
on_poll: OnPollCallback | None = None,
**kwargs,
) -> None:
"""Initialize a new Athena database connection.
Expand Down Expand Up @@ -224,6 +227,10 @@ def __init__(
result_reuse_enable: Enable Athena query result reuse. Defaults to False.
result_reuse_minutes: Minutes to reuse cached results.
on_start_query_execution: Callback function called when query starts.
on_poll: Callback invoked once per poll iteration with the current
execution object (``AthenaQueryExecution``, or
``AthenaCalculationExecutionStatus`` for Spark). Useful for
monitoring live query progress. Defaults to None.
**kwargs: Additional arguments passed to boto3 Session and client.

Raises:
Expand Down Expand Up @@ -331,6 +338,7 @@ def __init__(
self.result_reuse_enable = result_reuse_enable
self.result_reuse_minutes = result_reuse_minutes
self.on_start_query_execution = on_start_query_execution
self.on_poll = on_poll

def _assume_role(
self,
Expand Down Expand Up @@ -555,6 +563,7 @@ def cursor(
on_start_query_execution=kwargs.pop(
"on_start_query_execution", self.on_start_query_execution
),
on_poll=kwargs.pop("on_poll", self.on_poll),
**kwargs,
)

Expand Down
2 changes: 0 additions & 2 deletions pyathena/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@ def __init__(
kill_on_interrupt: bool = True,
result_reuse_enable: bool = False,
result_reuse_minutes: int = CursorIterator.DEFAULT_RESULT_REUSE_MINUTES,
on_start_query_execution: Callable[[str], None] | None = None,
**kwargs,
) -> None:
super().__init__(
Expand All @@ -69,7 +68,6 @@ def __init__(
**kwargs,
)
self._result_set_class = AthenaResultSet
self._on_start_query_execution = on_start_query_execution

@property
def arraysize(self) -> int:
Expand Down
3 changes: 0 additions & 3 deletions pyathena/pandas/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,6 @@ def __init__(
result_reuse_enable: bool = False,
result_reuse_minutes: int = CursorIterator.DEFAULT_RESULT_REUSE_MINUTES,
auto_optimize_chunksize: bool = False,
on_start_query_execution: Callable[[str], None] | None = None,
**kwargs,
) -> None:
"""Initialize PandasCursor with configuration options.
Expand All @@ -105,7 +104,6 @@ def __init__(
auto_optimize_chunksize: Enable automatic chunksize determination for
large files. Only effective when chunksize is None.
Default: False (no automatic chunking).
on_start_query_execution: Callback for query start events.
**kwargs: Additional arguments passed to pandas.read_csv.
"""
super().__init__(
Expand All @@ -128,7 +126,6 @@ def __init__(
self._cache_type = cache_type
self._max_workers = max_workers
self._auto_optimize_chunksize = auto_optimize_chunksize
self._on_start_query_execution = on_start_query_execution

@staticmethod
def get_default_converter(
Expand Down
3 changes: 0 additions & 3 deletions pyathena/polars/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,6 @@ def __init__(
unload: bool = False,
result_reuse_enable: bool = False,
result_reuse_minutes: int = CursorIterator.DEFAULT_RESULT_REUSE_MINUTES,
on_start_query_execution: Callable[[str], None] | None = None,
block_size: int | None = None,
cache_type: str | None = None,
max_workers: int = (cpu_count() or 1) * 5,
Expand All @@ -95,7 +94,6 @@ def __init__(
unload: Enable UNLOAD for high-performance Parquet output.
result_reuse_enable: Enable Athena query result reuse.
result_reuse_minutes: Minutes to reuse cached results.
on_start_query_execution: Callback invoked when query starts.
block_size: S3 read block size.
cache_type: S3 caching strategy.
max_workers: Maximum worker threads for parallel S3 operations.
Expand Down Expand Up @@ -123,7 +121,6 @@ def __init__(
**kwargs,
)
self._unload = unload
self._on_start_query_execution = on_start_query_execution
self._block_size = block_size
self._cache_type = cache_type
self._max_workers = max_workers
Expand Down
3 changes: 0 additions & 3 deletions pyathena/s3fs/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@ def __init__(
kill_on_interrupt: bool = True,
result_reuse_enable: bool = False,
result_reuse_minutes: int = CursorIterator.DEFAULT_RESULT_REUSE_MINUTES,
on_start_query_execution: Callable[[str], None] | None = None,
csv_reader: CSVReaderType | None = None,
**kwargs,
) -> None:
Expand All @@ -75,7 +74,6 @@ def __init__(
kill_on_interrupt: Cancel running query on keyboard interrupt.
result_reuse_enable: Enable Athena query result reuse.
result_reuse_minutes: Minutes to reuse cached results.
on_start_query_execution: Callback invoked when query starts.
csv_reader: CSV reader class to use for parsing results.
Use AthenaCSVReader (default) to distinguish between NULL
(unquoted empty) and empty string (quoted empty "").
Expand Down Expand Up @@ -104,7 +102,6 @@ def __init__(
result_reuse_minutes=result_reuse_minutes,
**kwargs,
)
self._on_start_query_execution = on_start_query_execution
self._csv_reader = csv_reader

@staticmethod
Expand Down
2 changes: 2 additions & 0 deletions pyathena/spark/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,8 @@ def _terminate_session(self) -> None:
def __poll(self, query_id: str) -> AthenaQueryExecution | AthenaCalculationExecution:
while True:
calculation_status = self._get_calculation_execution_status(query_id)
if self._on_poll:
self._on_poll(calculation_status)
if calculation_status.state in [
AthenaCalculationExecutionStatus.STATE_COMPLETED,
AthenaCalculationExecutionStatus.STATE_FAILED,
Expand Down
28 changes: 28 additions & 0 deletions tests/pyathena/spark/test_spark_cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@
import time
from concurrent.futures import ThreadPoolExecutor
from random import randint
from unittest.mock import MagicMock, patch

import pytest

from pyathena import DatabaseError, OperationalError
from pyathena.model import AthenaCalculationExecutionStatus
from pyathena.spark.cursor import SparkCursor
from tests import ENV


Expand Down Expand Up @@ -145,3 +147,29 @@ def cancel(c):
)
),
)


def test_spark_on_poll_invoked_each_iteration():
    """Verify ``on_poll`` receives every polled Spark calculation status.

    Runs entirely against mocks (no AWS): the status lookup is patched to yield
    CREATING, RUNNING and COMPLETED in turn, and the callback must observe each
    of them, in order, before ``_poll`` returns the final calculation execution.
    """
    expected_states = [
        AthenaCalculationExecutionStatus.STATE_CREATING,
        AthenaCalculationExecutionStatus.STATE_RUNNING,
        AthenaCalculationExecutionStatus.STATE_COMPLETED,
    ]
    polled_statuses = []
    for state in expected_states:
        polled_statuses.append(MagicMock(state=state))
    calculation_execution = MagicMock()
    observed = []

    # Allocate without running __init__, which would otherwise reach out to AWS.
    spark_cursor = SparkCursor.__new__(SparkCursor)
    spark_cursor._on_poll = observed.append
    spark_cursor._kill_on_interrupt = False
    spark_cursor._poll_interval = 0

    # Build both patchers up front; they only take effect inside the ``with``.
    status_patch = patch.object(
        SparkCursor, "_get_calculation_execution_status", side_effect=polled_statuses
    )
    execution_patch = patch.object(
        SparkCursor, "_get_calculation_execution", return_value=calculation_execution
    )
    with status_patch, execution_patch:
        outcome = spark_cursor._poll("calculation_id")

    observed_states = [status.state for status in observed]
    assert observed_states == expected_states
    assert outcome is calculation_execution
Loading