Skip to content

Commit 2eaf8fc

Browse files
committed
fix: address PR review comments
- Track channel ownership in sync/async clients; close() is a no-op when the channel was provided externally via the channel= parameter. - Raise ValueError in GrpcRetryPolicyOptions.__post_init__ when either backoff duration rounds to zero at 9-decimal-place precision, rather than silently emitting a zero-duration string that gRPC rejects.
1 parent 03a44d9 commit 2eaf8fc

3 files changed

Lines changed: 42 additions & 5 deletions

File tree

durabletask/client.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,7 @@ def __init__(self, *,
160160
default_version: Optional[str] = None,
161161
payload_store: Optional[PayloadStore] = None):
162162

163+
self._owns_channel = channel is None
163164
if channel is None:
164165
interceptors = prepare_sync_interceptors(metadata, interceptors)
165166
channel = shared.get_grpc_channel(
@@ -175,8 +176,15 @@ def __init__(self, *,
175176
self._payload_store = payload_store
176177

177178
def close(self) -> None:
178-
"""Closes the underlying gRPC channel."""
179-
self._channel.close()
179+
"""Closes the underlying gRPC channel.
180+
181+
Only closes channels created internally. If a pre-configured channel
182+
was passed via the ``channel`` constructor parameter, this method is
183+
a no-op — the caller retains ownership and is responsible for closing
184+
it.
185+
"""
186+
if self._owns_channel:
187+
self._channel.close()
180188

181189
def schedule_new_orchestration(self, orchestrator: Union[task.Orchestrator[TInput, TOutput], str], *,
182190
input: Optional[TInput] = None,
@@ -444,6 +452,7 @@ def __init__(self, *,
444452
default_version: Optional[str] = None,
445453
payload_store: Optional[PayloadStore] = None):
446454

455+
self._owns_channel = channel is None
447456
if channel is None:
448457
interceptors = prepare_async_interceptors(metadata, interceptors)
449458
channel = shared.get_async_grpc_channel(
@@ -459,8 +468,15 @@ def __init__(self, *,
459468
self._payload_store = payload_store
460469

461470
async def close(self) -> None:
462-
"""Closes the underlying gRPC channel."""
463-
await self._channel.close()
471+
"""Closes the underlying gRPC channel.
472+
473+
Only closes channels created internally. If a pre-configured channel
474+
was passed via the ``channel`` constructor parameter, this method is
475+
a no-op — the caller retains ownership and is responsible for closing
476+
it.
477+
"""
478+
if self._owns_channel:
479+
await self._channel.close()
464480

465481
async def __aenter__(self):
466482
return self

durabletask/grpc_options.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,21 @@ def __post_init__(self) -> None:
3131
raise ValueError("max_backoff_seconds must be >= initial_backoff_seconds")
3232
if len(self.retryable_status_codes) == 0:
3333
raise ValueError("retryable_status_codes cannot be empty")
34+
# Validate that backoff values are representable as non-zero gRPC duration strings.
35+
self._format_duration(self.initial_backoff_seconds)
36+
self._format_duration(self.max_backoff_seconds)
3437

3538
@staticmethod
3639
def _format_duration(seconds: float) -> str:
37-
return f"{seconds:.3f}s"
40+
formatted = f"{seconds:.9f}".rstrip('0')
41+
if formatted.endswith('.'):
42+
formatted += '0'
43+
if float(formatted) == 0:
44+
raise ValueError(
45+
f"Duration {seconds!r} rounds to zero; use a value large enough to "
46+
"produce a non-zero gRPC duration string."
47+
)
48+
return f"{formatted}s"
3849

3950
def to_service_config(self) -> dict[str, Any]:
4051
return {

tests/durabletask/test_client.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import json
2+
import pytest
23
from unittest.mock import ANY, MagicMock, patch
34

45
from durabletask.client import AsyncTaskHubGrpcClient, TaskHubGrpcClient
@@ -62,6 +63,15 @@ def test_get_grpc_channel_with_retry_policy_service_config():
6263
assert retry_policy['retryableStatusCodes'] == ['UNAVAILABLE']
6364

6465

66+
def test_retry_policy_format_duration_raises_on_zero():
67+
with pytest.raises(ValueError, match="rounds to zero"):
68+
GrpcRetryPolicyOptions(
69+
max_attempts=2,
70+
initial_backoff_seconds=1e-15,
71+
max_backoff_seconds=1e-15,
72+
)
73+
74+
6575
def test_get_grpc_channel_default_host_address():
6676
with patch('grpc.insecure_channel') as mock_channel:
6777
get_grpc_channel(None, False, interceptors=INTERCEPTORS)

0 commit comments

Comments
 (0)