Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from dapr.ext.workflow.workflow_state import WorkflowState
from grpc.aio import AioRpcError

from dapr.aio.clients.grpc.interceptors import DaprClientTimeoutInterceptorAsync
from dapr.clients import DaprInternalError
from dapr.clients.http.client import DAPR_API_TOKEN_HEADER
from dapr.conf import settings
Expand Down Expand Up @@ -68,6 +69,7 @@ def __init__(
secure_channel=uri.tls,
log_handler=options.log_handler,
log_formatter=options.log_formatter,
interceptors=[DaprClientTimeoutInterceptorAsync()],
)

async def schedule_new_workflow(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from grpc import RpcError

from dapr.clients import DaprInternalError
from dapr.clients.grpc.interceptors import DaprClientTimeoutInterceptor
from dapr.clients.http.client import DAPR_API_TOKEN_HEADER
from dapr.conf import settings
from dapr.conf.helpers import GrpcEndpoint
Expand Down Expand Up @@ -70,6 +71,7 @@ def __init__(
secure_channel=uri.tls,
log_handler=options.log_handler,
log_formatter=options.log_formatter,
interceptors=[DaprClientTimeoutInterceptor()],
)

def schedule_new_workflow(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from dapr.ext.workflow.workflow_context import Workflow

from dapr.clients import DaprInternalError
from dapr.clients.grpc.interceptors import DaprClientTimeoutInterceptor
from dapr.clients.http.client import DAPR_API_TOKEN_HEADER
from dapr.conf import settings
from dapr.conf.helpers import GrpcEndpoint
Expand Down Expand Up @@ -73,13 +74,17 @@ def __init__(
raise DaprInternalError(f'{error}') from error

options = self._logger.get_options()
all_interceptors = []
if interceptors:
all_interceptors.extend(interceptors)
all_interceptors.append(DaprClientTimeoutInterceptor())
self.__worker = worker.TaskHubGrpcWorker(
host_address=uri.endpoint,
metadata=metadata,
secure_channel=uri.tls,
log_handler=options.log_handler,
log_formatter=options.log_formatter,
interceptors=interceptors,
interceptors=all_interceptors,
concurrency_options=worker.ConcurrencyOptions(
maximum_concurrent_activity_work_items=maximum_concurrent_activity_work_items,
maximum_concurrent_orchestration_work_items=maximum_concurrent_orchestration_work_items,
Expand Down
20 changes: 19 additions & 1 deletion ext/dapr-ext-workflow/tests/test_workflow_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,12 @@
from typing import Any, Union
from unittest import mock

from durabletask import client
from grpc import RpcError

from dapr.ext.workflow._durabletask import client
from dapr.ext.workflow.dapr_workflow_client import DaprWorkflowClient
from dapr.ext.workflow.dapr_workflow_context import DaprWorkflowContext
from grpc import RpcError

mock_schedule_result = 'workflow001'
mock_raise_event_result = 'event001'
Expand Down Expand Up @@ -111,6 +113,20 @@ def _inner_get_orchestration_state(self, instance_id, state: client.Orchestratio
)


class WorkflowClientTimeoutInterceptorTest(unittest.TestCase):
def test_timeout_interceptor_is_passed_to_client(self):
with mock.patch('durabletask.client.TaskHubGrpcClient') as mock_client_cls:
DaprWorkflowClient()
mock_client_cls.assert_called_once()
call_kwargs = mock_client_cls.call_args[1]
interceptors = call_kwargs['interceptors']
self.assertEqual(len(interceptors), 1)
from dapr.clients.grpc.interceptors import \
DaprClientTimeoutInterceptor

self.assertIsInstance(interceptors[0], DaprClientTimeoutInterceptor)


class WorkflowClientTest(unittest.TestCase):
def mock_client_wf(ctx: DaprWorkflowContext, input):
print(f'{input}')
Expand Down Expand Up @@ -186,3 +202,5 @@ def test_client_functions(self):

actual_purge_result = wfClient.purge_workflow(instance_id=mock_instance_id)
assert actual_purge_result == mock_purge_result
actual_purge_result = wfClient.purge_workflow(instance_id=mock_instance_id)
assert actual_purge_result == mock_purge_result
20 changes: 19 additions & 1 deletion ext/dapr-ext-workflow/tests/test_workflow_client_aio.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,12 @@
from typing import Any, Union
from unittest import mock

from durabletask import client
from grpc.aio import AioRpcError

from dapr.ext.workflow._durabletask import client
from dapr.ext.workflow.aio import DaprWorkflowClient
from dapr.ext.workflow.dapr_workflow_context import DaprWorkflowContext
from grpc.aio import AioRpcError

mock_schedule_result = 'workflow001'
mock_raise_event_result = 'event001'
Expand Down Expand Up @@ -112,6 +114,20 @@ def _inner_get_orchestration_state(self, instance_id, state: client.Orchestratio
)


class WorkflowClientAioTimeoutInterceptorTest(unittest.IsolatedAsyncioTestCase):
async def test_timeout_interceptor_is_passed_to_client(self):
with mock.patch('durabletask.aio.client.AsyncTaskHubGrpcClient') as mock_client_cls:
DaprWorkflowClient()
mock_client_cls.assert_called_once()
call_kwargs = mock_client_cls.call_args[1]
interceptors = call_kwargs['interceptors']
self.assertEqual(len(interceptors), 1)
from dapr.aio.clients.grpc.interceptors import \
DaprClientTimeoutInterceptorAsync

self.assertIsInstance(interceptors[0], DaprClientTimeoutInterceptorAsync)


class WorkflowClientAioTest(unittest.IsolatedAsyncioTestCase):
def mock_client_wf(ctx: DaprWorkflowContext, input):
print(f'{input}')
Expand Down Expand Up @@ -190,3 +206,5 @@ async def test_client_functions(self):

actual_purge_result = await wfClient.purge_workflow(instance_id=mock_instance_id)
assert actual_purge_result == mock_purge_result
actual_purge_result = await wfClient.purge_workflow(instance_id=mock_instance_id)
assert actual_purge_result == mock_purge_result
63 changes: 62 additions & 1 deletion ext/dapr-ext-workflow/tests/test_workflow_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,12 @@
from typing import List, Optional
from unittest import mock

import grpc
from pydantic import BaseModel, ValidationError

from dapr.ext.workflow.dapr_workflow_context import DaprWorkflowContext
from dapr.ext.workflow.workflow_activity_context import WorkflowActivityContext
from dapr.ext.workflow.workflow_runtime import WorkflowRuntime, alternate_name
from pydantic import BaseModel, ValidationError


class Order(BaseModel):
Expand All @@ -46,6 +48,59 @@ def add_named_activity(self, name: str, fn):
self._activity_fns[name] = fn


class WorkflowRuntimeTimeoutInterceptorTest(unittest.TestCase):
def setUp(self):
listActivities.clear()
listOrchestrators.clear()
self._registry_patch = mock.patch(
'durabletask.worker._Registry', return_value=FakeTaskHubGrpcWorker()
)
self._registry_patch.start()

def tearDown(self):
mock.patch.stopall()

def test_timeout_interceptor_is_prepended(self):
with mock.patch('durabletask.worker.TaskHubGrpcWorker') as mock_worker_cls:
WorkflowRuntime()
mock_worker_cls.assert_called_once()
call_kwargs = mock_worker_cls.call_args[1]
interceptors = call_kwargs['interceptors']
self.assertEqual(len(interceptors), 1)
from dapr.clients.grpc.interceptors import \
DaprClientTimeoutInterceptor

self.assertIsInstance(interceptors[0], DaprClientTimeoutInterceptor)

def test_timeout_interceptor_with_custom_interceptors(self):
custom_interceptor = mock.MagicMock(spec=grpc.UnaryUnaryClientInterceptor)
with mock.patch('durabletask.worker.TaskHubGrpcWorker') as mock_worker_cls:
WorkflowRuntime(interceptors=[custom_interceptor])
call_kwargs = mock_worker_cls.call_args[1]
interceptors = call_kwargs['interceptors']
self.assertEqual(len(interceptors), 2)
from dapr.clients.grpc.interceptors import \
DaprClientTimeoutInterceptor

self.assertIsInstance(interceptors[0], DaprClientTimeoutInterceptor)
self.assertIs(interceptors[1], custom_interceptor)

def test_timeout_interceptor_preserves_custom_interceptor_order(self):
custom1 = mock.MagicMock(spec=grpc.UnaryUnaryClientInterceptor)
custom2 = mock.MagicMock(spec=grpc.UnaryStreamClientInterceptor)
with mock.patch('durabletask.worker.TaskHubGrpcWorker') as mock_worker_cls:
WorkflowRuntime(interceptors=[custom1, custom2])
call_kwargs = mock_worker_cls.call_args[1]
interceptors = call_kwargs['interceptors']
self.assertEqual(len(interceptors), 3)
from dapr.clients.grpc.interceptors import \
DaprClientTimeoutInterceptor

self.assertIsInstance(interceptors[0], DaprClientTimeoutInterceptor)
self.assertIs(interceptors[1], custom1)
self.assertIs(interceptors[2], custom2)


class WorkflowRuntimeTest(unittest.TestCase):
def setUp(self):
listActivities.clear()
Expand Down Expand Up @@ -765,3 +820,9 @@ def my_act(ctx, order: Optional[Order]):
wrapper = self.fake_registry._activity_fns['optional_no_default_act']

self.assertIsNone(wrapper(mock.MagicMock(), None))
wrapper = self.fake_registry._activity_fns['optional_no_default_act']

self.assertIsNone(wrapper(mock.MagicMock(), None))
wrapper = self.fake_registry._activity_fns['optional_no_default_act']

self.assertIsNone(wrapper(mock.MagicMock(), None))
Loading