diff --git a/ext/dapr-ext-workflow/dapr/ext/workflow/aio/dapr_workflow_client.py b/ext/dapr-ext-workflow/dapr/ext/workflow/aio/dapr_workflow_client.py index fb8367a04..b72be558d 100644 --- a/ext/dapr-ext-workflow/dapr/ext/workflow/aio/dapr_workflow_client.py +++ b/ext/dapr-ext-workflow/dapr/ext/workflow/aio/dapr_workflow_client.py @@ -26,6 +26,7 @@ from dapr.ext.workflow.workflow_state import WorkflowState from grpc.aio import AioRpcError +from dapr.aio.clients.grpc.interceptors import DaprClientTimeoutInterceptorAsync from dapr.clients import DaprInternalError from dapr.clients.http.client import DAPR_API_TOKEN_HEADER from dapr.conf import settings @@ -68,6 +69,7 @@ def __init__( secure_channel=uri.tls, log_handler=options.log_handler, log_formatter=options.log_formatter, + interceptors=[DaprClientTimeoutInterceptorAsync()], ) async def schedule_new_workflow( diff --git a/ext/dapr-ext-workflow/dapr/ext/workflow/dapr_workflow_client.py b/ext/dapr-ext-workflow/dapr/ext/workflow/dapr_workflow_client.py index bdb72aca4..d732f7747 100644 --- a/ext/dapr-ext-workflow/dapr/ext/workflow/dapr_workflow_client.py +++ b/ext/dapr-ext-workflow/dapr/ext/workflow/dapr_workflow_client.py @@ -26,6 +26,7 @@ from grpc import RpcError from dapr.clients import DaprInternalError +from dapr.clients.grpc.interceptors import DaprClientTimeoutInterceptor from dapr.clients.http.client import DAPR_API_TOKEN_HEADER from dapr.conf import settings from dapr.conf.helpers import GrpcEndpoint @@ -70,6 +71,7 @@ def __init__( secure_channel=uri.tls, log_handler=options.log_handler, log_formatter=options.log_formatter, + interceptors=[DaprClientTimeoutInterceptor()], ) def schedule_new_workflow( diff --git a/ext/dapr-ext-workflow/dapr/ext/workflow/workflow_runtime.py b/ext/dapr-ext-workflow/dapr/ext/workflow/workflow_runtime.py index 11bae78ac..f33622a15 100644 --- a/ext/dapr-ext-workflow/dapr/ext/workflow/workflow_runtime.py +++ b/ext/dapr-ext-workflow/dapr/ext/workflow/workflow_runtime.py @@ -27,6 +27,7 @@ from dapr.ext.workflow.workflow_context import Workflow from dapr.clients import DaprInternalError +from dapr.clients.grpc.interceptors import DaprClientTimeoutInterceptor from dapr.clients.http.client import DAPR_API_TOKEN_HEADER from dapr.conf import settings from dapr.conf.helpers import GrpcEndpoint @@ -73,13 +74,17 @@ def __init__( raise DaprInternalError(f'{error}') from error options = self._logger.get_options() + all_interceptors = [] + if interceptors: + all_interceptors.extend(interceptors) + all_interceptors.append(DaprClientTimeoutInterceptor()) self.__worker = worker.TaskHubGrpcWorker( host_address=uri.endpoint, metadata=metadata, secure_channel=uri.tls, log_handler=options.log_handler, log_formatter=options.log_formatter, - interceptors=interceptors, + interceptors=all_interceptors, concurrency_options=worker.ConcurrencyOptions( maximum_concurrent_activity_work_items=maximum_concurrent_activity_work_items, maximum_concurrent_orchestration_work_items=maximum_concurrent_orchestration_work_items, diff --git a/ext/dapr-ext-workflow/tests/test_workflow_client.py b/ext/dapr-ext-workflow/tests/test_workflow_client.py index 26dcbdb61..5ba6c4f00 100644 --- a/ext/dapr-ext-workflow/tests/test_workflow_client.py +++ b/ext/dapr-ext-workflow/tests/test_workflow_client.py @@ -18,10 +18,12 @@ from typing import Any, Union from unittest import mock +from durabletask import client +from grpc import RpcError + from dapr.ext.workflow._durabletask import client from dapr.ext.workflow.dapr_workflow_client import DaprWorkflowClient from dapr.ext.workflow.dapr_workflow_context import DaprWorkflowContext -from grpc import RpcError mock_schedule_result = 'workflow001' mock_raise_event_result = 'event001' @@ -111,6 +113,20 @@ def _inner_get_orchestration_state(self, instance_id, state: client.Orchestratio ) +class WorkflowClientTimeoutInterceptorTest(unittest.TestCase): + def test_timeout_interceptor_is_passed_to_client(self): + with mock.patch('durabletask.client.TaskHubGrpcClient') as mock_client_cls: + DaprWorkflowClient() + mock_client_cls.assert_called_once() + call_kwargs = mock_client_cls.call_args[1] + interceptors = call_kwargs['interceptors'] + self.assertEqual(len(interceptors), 1) + from dapr.clients.grpc.interceptors import \ + DaprClientTimeoutInterceptor + + self.assertIsInstance(interceptors[0], DaprClientTimeoutInterceptor) + + class WorkflowClientTest(unittest.TestCase): def mock_client_wf(ctx: DaprWorkflowContext, input): print(f'{input}') @@ -186,3 +202,5 @@ def test_client_functions(self): actual_purge_result = wfClient.purge_workflow(instance_id=mock_instance_id) assert actual_purge_result == mock_purge_result + actual_purge_result = wfClient.purge_workflow(instance_id=mock_instance_id) + assert actual_purge_result == mock_purge_result diff --git a/ext/dapr-ext-workflow/tests/test_workflow_client_aio.py b/ext/dapr-ext-workflow/tests/test_workflow_client_aio.py index 6e5d610f3..7f9081f35 100644 --- a/ext/dapr-ext-workflow/tests/test_workflow_client_aio.py +++ b/ext/dapr-ext-workflow/tests/test_workflow_client_aio.py @@ -18,10 +18,12 @@ from typing import Any, Union from unittest import mock +from durabletask import client +from grpc.aio import AioRpcError + from dapr.ext.workflow._durabletask import client from dapr.ext.workflow.aio import DaprWorkflowClient from dapr.ext.workflow.dapr_workflow_context import DaprWorkflowContext -from grpc.aio import AioRpcError mock_schedule_result = 'workflow001' mock_raise_event_result = 'event001' @@ -112,6 +114,20 @@ def _inner_get_orchestration_state(self, instance_id, state: client.Orchestratio ) +class WorkflowClientAioTimeoutInterceptorTest(unittest.IsolatedAsyncioTestCase): + async def test_timeout_interceptor_is_passed_to_client(self): + with mock.patch('durabletask.aio.client.AsyncTaskHubGrpcClient') as mock_client_cls: + DaprWorkflowClient() + mock_client_cls.assert_called_once() + call_kwargs = mock_client_cls.call_args[1] + interceptors = call_kwargs['interceptors'] + self.assertEqual(len(interceptors), 1) + from dapr.aio.clients.grpc.interceptors import \ + DaprClientTimeoutInterceptorAsync + + self.assertIsInstance(interceptors[0], DaprClientTimeoutInterceptorAsync) + + class WorkflowClientAioTest(unittest.IsolatedAsyncioTestCase): def mock_client_wf(ctx: DaprWorkflowContext, input): print(f'{input}') @@ -190,3 +206,5 @@ async def test_client_functions(self): actual_purge_result = await wfClient.purge_workflow(instance_id=mock_instance_id) assert actual_purge_result == mock_purge_result + actual_purge_result = await wfClient.purge_workflow(instance_id=mock_instance_id) + assert actual_purge_result == mock_purge_result diff --git a/ext/dapr-ext-workflow/tests/test_workflow_runtime.py b/ext/dapr-ext-workflow/tests/test_workflow_runtime.py index 233bd032f..2810bf2b5 100644 --- a/ext/dapr-ext-workflow/tests/test_workflow_runtime.py +++ b/ext/dapr-ext-workflow/tests/test_workflow_runtime.py @@ -17,10 +17,12 @@ from typing import List, Optional from unittest import mock +import grpc +from pydantic import BaseModel, ValidationError + from dapr.ext.workflow.dapr_workflow_context import DaprWorkflowContext from dapr.ext.workflow.workflow_activity_context import WorkflowActivityContext from dapr.ext.workflow.workflow_runtime import WorkflowRuntime, alternate_name -from pydantic import BaseModel, ValidationError class Order(BaseModel): @@ -46,6 +48,59 @@ def add_named_activity(self, name: str, fn): self._activity_fns[name] = fn +class WorkflowRuntimeTimeoutInterceptorTest(unittest.TestCase): + def setUp(self): + listActivities.clear() + listOrchestrators.clear() + self._registry_patch = mock.patch( + 'durabletask.worker._Registry', return_value=FakeTaskHubGrpcWorker() + ) + self._registry_patch.start() + + def tearDown(self): + mock.patch.stopall() + + def test_timeout_interceptor_is_prepended(self): + with mock.patch('durabletask.worker.TaskHubGrpcWorker') as mock_worker_cls: + WorkflowRuntime() + mock_worker_cls.assert_called_once() + call_kwargs = mock_worker_cls.call_args[1] + interceptors = call_kwargs['interceptors'] + self.assertEqual(len(interceptors), 1) + from dapr.clients.grpc.interceptors import \ + DaprClientTimeoutInterceptor + + self.assertIsInstance(interceptors[0], DaprClientTimeoutInterceptor) + + def test_timeout_interceptor_with_custom_interceptors(self): + custom_interceptor = mock.MagicMock(spec=grpc.UnaryUnaryClientInterceptor) + with mock.patch('durabletask.worker.TaskHubGrpcWorker') as mock_worker_cls: + WorkflowRuntime(interceptors=[custom_interceptor]) + call_kwargs = mock_worker_cls.call_args[1] + interceptors = call_kwargs['interceptors'] + self.assertEqual(len(interceptors), 2) + from dapr.clients.grpc.interceptors import \ + DaprClientTimeoutInterceptor + + self.assertIsInstance(interceptors[0], DaprClientTimeoutInterceptor) + self.assertIs(interceptors[1], custom_interceptor) + + def test_timeout_interceptor_preserves_custom_interceptor_order(self): + custom1 = mock.MagicMock(spec=grpc.UnaryUnaryClientInterceptor) + custom2 = mock.MagicMock(spec=grpc.UnaryStreamClientInterceptor) + with mock.patch('durabletask.worker.TaskHubGrpcWorker') as mock_worker_cls: + WorkflowRuntime(interceptors=[custom1, custom2]) + call_kwargs = mock_worker_cls.call_args[1] + interceptors = call_kwargs['interceptors'] + self.assertEqual(len(interceptors), 3) + from dapr.clients.grpc.interceptors import \ + DaprClientTimeoutInterceptor + + self.assertIsInstance(interceptors[0], DaprClientTimeoutInterceptor) + self.assertIs(interceptors[1], custom1) + self.assertIs(interceptors[2], custom2) + + class WorkflowRuntimeTest(unittest.TestCase): def setUp(self): listActivities.clear() @@ -765,3 +820,9 @@ def my_act(ctx, order: Optional[Order]): wrapper = self.fake_registry._activity_fns['optional_no_default_act'] self.assertIsNone(wrapper(mock.MagicMock(), None)) + wrapper = self.fake_registry._activity_fns['optional_no_default_act'] + + self.assertIsNone(wrapper(mock.MagicMock(), None)) + wrapper = self.fake_registry._activity_fns['optional_no_default_act'] + + self.assertIsNone(wrapper(mock.MagicMock(), None))