diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 2b8ec6b..dfa7a86 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -2,11 +2,11 @@ name: CI on: push: - branches: [main] - tags: [v*.*.*] + branches: [ main ] + tags: [ v*.*.* ] pull_request: - branches: [ "main" ] + branches: [ main ] types: - synchronize - opened diff --git a/hello/advertizer.py b/hello/advertizer.py index c95629e..3aea6af 100644 --- a/hello/advertizer.py +++ b/hello/advertizer.py @@ -49,12 +49,15 @@ def stop(self) -> None: self._sender.stop() def advertise(self, info: ServiceInfo | None = None) -> None: + if info: + self._info = info + if self._group: - if info: - self._info = info if self._info: self._sender.send(self._info) log.info('Service advertised', service=self._info, group=self._group) + else: + log.warning('Cannot advertise service, no service info provided', group=self._group) else: log.warning('Cannot advertise service, advertizer not started', service=info) @@ -72,21 +75,22 @@ def start(self, group: Group, info: ServiceInfo | None = None) -> None: self._receiver.register(self._handle_message) def stop(self) -> None: - super().stop() + self._receiver.deregister(self._handle_message) self._receiver.stop() + super().stop() def _handle_message(self, message: dict[str, Any]) -> None: if self._info: try: query = ServiceQuery(**message) - log.debug('Query received', group=self._group, query=query) - self._handle_query(query, self._info) + matcher = ServiceMatcher(query) + log.debug('Service query received', group=self._group, query=query) + self._handle_query(matcher, self._info) except Exception as error: - log.warning('Invalid query message received', group=self._group, received=message, error=error) + log.warning('Invalid service query received', group=self._group, received=message, error=error) - def _handle_query(self, query: ServiceQuery, info: ServiceInfo) -> None: - matcher = ServiceMatcher(query) - if matcher and matcher.matches(info): + def _handle_query(self, matcher: ServiceMatcher, info: ServiceInfo) -> None: + if matcher.matches(info): delay = round(self._max_delay * random.random(), 3) log.info('Responding to query', group=self._group, query=matcher.query, service=info, delay=delay) time.sleep(delay) diff --git a/hello/discoverer.py b/hello/discoverer.py index ea0cf24..e99e62f 100644 --- a/hello/discoverer.py +++ b/hello/discoverer.py @@ -2,6 +2,7 @@ # SPDX-FileCopyrightText: 2024 Attila Gombos # SPDX-License-Identifier: MIT +from concurrent.futures import ThreadPoolExecutor from dataclasses import dataclass from enum import Enum from typing import Any, Protocol @@ -22,6 +23,8 @@ class DiscoveryEventType(Enum): @dataclass class DiscoveryEvent: + group: Group + query: ServiceQuery service: ServiceInfo type: DiscoveryEventType @@ -41,28 +44,26 @@ def stop(self) -> None: def discover(self, query: ServiceQuery | None = None) -> None: raise NotImplementedError() - def get_services(self) -> dict[UUID, ServiceInfo]: - raise NotImplementedError() - def register(self, handler: OnDiscoveryEvent) -> None: raise NotImplementedError() def deregister(self, handler: OnDiscoveryEvent) -> None: raise NotImplementedError() - def get_handlers(self) -> list[OnDiscoveryEvent]: + def get_services(self) -> dict[UUID, ServiceInfo]: raise NotImplementedError() class DefaultDiscoverer(Discoverer): - def __init__(self, sender: Sender, receiver: Receiver) -> None: + def __init__(self, sender: Sender, receiver: Receiver, max_workers: int = 8) -> None: self._sender = sender self._receiver = receiver self._group: Group | None = None self._matcher: ServiceMatcher | None = None - self._cache: dict[UUID, ServiceInfo] = {} + self._services: dict[UUID, ServiceInfo] = {} self._handlers: list[OnDiscoveryEvent] = [] + self._handler_executor = ThreadPoolExecutor(max_workers=max_workers) def __enter__(self) -> Discoverer: return self @@ -80,65 +81,73 @@ def start(self, group: Group, query: ServiceQuery | None = None) -> None: def stop(self) -> None: self._group = None + self._matcher = None self._sender.stop() + self._receiver.deregister(self._handle_message) self._receiver.stop() def discover(self, query: ServiceQuery | None = None) -> None: + if query: + self._matcher = ServiceMatcher(query) + if self._group: - if query: - self._matcher = ServiceMatcher(query) if self._matcher: self._sender.send(self._matcher.query) - log.info('Service discovery initiated', query=self._matcher.query, group=self._group) + log.info('Service discovery initiated', group=self._group, query=self._matcher.query) + else: + log.warning('Cannot discover services, no query provided', group=self._group) else: log.warning('Cannot discover services, discoverer not started', query=query) - def get_services(self) -> dict[UUID, ServiceInfo]: - return self._cache.copy() - def register(self, handler: OnDiscoveryEvent) -> None: self._handlers.append(handler) def deregister(self, handler: OnDiscoveryEvent) -> None: self._handlers.remove(handler) - def get_handlers(self) -> list[OnDiscoveryEvent]: - return self._handlers.copy() + def get_services(self) -> dict[UUID, ServiceInfo]: + return self._services.copy() def _handle_message(self, message: dict[str, Any]) -> None: - try: - service = ServiceInfo(UUID(message['uuid']), message['name'], message['role'], message.get('urls', {})) - self._handle_service(service) - except Exception as error: - log.warn('Failed to handle received message', data=message, error=error) + if self._group and self._matcher: + try: + service = ServiceInfo(UUID(message['uuid']), message['name'], message['role'], message.get('urls', {})) + log.debug('Service info received', service=service, group=self._group) + self._handle_service(service, self._group, self._matcher) + except Exception as error: + log.warn('Invalid service info received', group=self._group, data=message, error=error) - def _handle_service(self, service: ServiceInfo) -> None: - if self._matcher and self._matcher.matches(service): - cached = self._cache.get(service.uuid) + def _handle_service(self, service: ServiceInfo, group: Group, matcher: ServiceMatcher) -> None: + if matcher.matches(service): + stored = self._services.get(service.uuid) - if event := self._create_event(cached, service): + if event := self._create_event(group, matcher, stored, service): self._handle_event(event) - def _create_event(self, cached: ServiceInfo | None, service: ServiceInfo) -> DiscoveryEvent | None: - if cached: - if cached != service: - log.info('Service updated', old_service=cached, new_service=service) - return DiscoveryEvent(service, DiscoveryEventType.UPDATED) + def _create_event(self, group: Group, matcher: ServiceMatcher, + stored: ServiceInfo | None, service: ServiceInfo) -> DiscoveryEvent | None: + if stored: + if stored != service: + log.info('Service updated', group=group, old_service=stored, new_service=service) + return DiscoveryEvent(group, matcher.query, service, DiscoveryEventType.UPDATED) else: - log.debug('Service unchanged', service=service) + log.debug('Service unchanged', group=group, service=service) return None else: - log.info('Service discovered', service=service) - return DiscoveryEvent(service, DiscoveryEventType.DISCOVERED) + log.info('New service discovered', group=group, service=service) + return DiscoveryEvent(group, matcher.query, service, DiscoveryEventType.DISCOVERED) def _handle_event(self, event: DiscoveryEvent) -> None: - service = event.service - self._cache[service.uuid] = service - for callback in self._handlers: - try: - callback(event) - except Exception as error: - log.warn('Error in event handler execution', event=event, error=error) + self._services[event.service.uuid] = event.service + + for handler in self._handlers: + self._handler_executor.submit(self._execute_handler, handler, event) + + def _execute_handler(self, handler: OnDiscoveryEvent, event: DiscoveryEvent) -> None: + try: + handler(event) + except Exception as error: + log.warn('Error in event handler execution', event=event, error=error) class ScheduledDiscoverer(DefaultScheduler[ServiceQuery], Discoverer): @@ -172,8 +181,5 @@ def register(self, handler: OnDiscoveryEvent) -> None: def deregister(self, handler: OnDiscoveryEvent) -> None: self._discoverer.deregister(handler) - def get_handlers(self) -> list[OnDiscoveryEvent]: - return self._discoverer.get_handlers() - def _execute(self, query: ServiceQuery | None = None) -> None: self.discover(query) diff --git a/hello/receiver.py b/hello/receiver.py index c032bc9..8c045ab 100644 --- a/hello/receiver.py +++ b/hello/receiver.py @@ -31,9 +31,6 @@ def register(self, handler: OnMessage) -> None: def deregister(self, handler: OnMessage) -> None: raise NotImplementedError() - def get_handlers(self) -> list[OnMessage]: - raise NotImplementedError() - class DishReceiver(Receiver): @@ -83,9 +80,6 @@ def register(self, handler: OnMessage) -> None: def deregister(self, handler: OnMessage) -> None: self._handlers.remove(handler) - def get_handlers(self) -> list[OnMessage]: - return self._handlers.copy() - def _receive_loop(self) -> None: while self._group: try: @@ -105,4 +99,4 @@ def _execute_handler(self, handler: OnMessage, message: dict[str, Any]) -> None: try: handler(message) except Exception as error: - log.warn('Error in message handler execution', data=message, group=self._group, error=error) + log.warn('Handler failed to process message', data=message, group=self._group, error=error) diff --git a/pyproject.toml b/pyproject.toml index 8622581..357e7b5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,6 +4,9 @@ description = "A service advertizer/discovery protocol library using ZeroMQ" authors = [ { name = "Ferenc Nandor Janky & Attila Gombos", email = "info@effective-range.com" } ] +maintainers = [ + { name = "Ferenc Nandor Janky & Attila Gombos", email = "info@effective-range.com" } +] dependencies = [ "pyzmq @ git+https://github.com/EffectiveRange/pyzmq.git@v27.1.1", "python-context-logger @ git+https://github.com/EffectiveRange/python-context-logger.git@latest", @@ -25,3 +28,8 @@ build-backend = "setuptools.build_meta" [tool.setuptools_scm] version_scheme = "guess-next-dev" local_scheme = "node-and-date" + +[tool.pytest] +addopts = ["--verbose", "--capture=no"] +python_files = ["*Test.py"] +python_classes = ["*Test"] diff --git a/tests/defaultDiscovererTest.py b/tests/defaultDiscovererTest.py index 6dc8e70..4a21c34 100644 --- a/tests/defaultDiscovererTest.py +++ b/tests/defaultDiscovererTest.py @@ -4,6 +4,7 @@ from uuid import uuid4 from context_logger import setup_logging +from test_utility import wait_for_assertion from hello import ServiceInfo, Group, ServiceQuery, DefaultDiscoverer, Sender, Receiver, OnDiscoveryEvent, \ DiscoveryEventType, DiscoveryEvent @@ -74,7 +75,7 @@ def test_registers_event_handler(self): discoverer.register(handler) # Then - self.assertIn(handler, discoverer.get_handlers()) + self.assertIn(handler, discoverer._handlers) def test_deregisters_event_handler(self): # Given @@ -88,7 +89,7 @@ def test_deregisters_event_handler(self): discoverer.deregister(handler) # Then - self.assertNotIn(handler, discoverer.get_handlers()) + self.assertNotIn(handler, discoverer._handlers) def test_caches_service_and_calls_handler_when_receives_matching_info(self): # Given @@ -104,7 +105,9 @@ def test_caches_service_and_calls_handler_when_receives_matching_info(self): # Then self.assertEqual({SERVICE_INFO.uuid: SERVICE_INFO}, discoverer.get_services()) - handler.assert_called_once_with(DiscoveryEvent(SERVICE_INFO, DiscoveryEventType.DISCOVERED)) + wait_for_assertion(1, lambda: handler.assert_called_once_with( + DiscoveryEvent(GROUP, SERVICE_QUERY, SERVICE_INFO, DiscoveryEventType.DISCOVERED) + )) def test_updates_service_and_calls_handler_when_receives_matching_info(self): # Given @@ -125,7 +128,9 @@ def test_updates_service_and_calls_handler_when_receives_matching_info(self): # Then self.assertEqual({SERVICE_INFO.uuid: new_service_info}, discoverer.get_services()) - handler.assert_called_once_with(DiscoveryEvent(new_service_info, DiscoveryEventType.UPDATED)) + wait_for_assertion(1, lambda: handler.assert_called_once_with( + DiscoveryEvent(GROUP, SERVICE_QUERY, new_service_info, DiscoveryEventType.UPDATED) + )) def test_does_not_call_handler_when_service_info_not_changed(self): # Given @@ -159,7 +164,9 @@ def test_handles_handler_error_gracefully(self): # Then self.assertEqual({SERVICE_INFO.uuid: SERVICE_INFO}, discoverer.get_services()) - handler.assert_called_once_with(DiscoveryEvent(SERVICE_INFO, DiscoveryEventType.DISCOVERED)) + wait_for_assertion(1, lambda: handler.assert_called_once_with( + DiscoveryEvent(GROUP, SERVICE_QUERY, SERVICE_INFO, DiscoveryEventType.DISCOVERED) + )) def test_handles_invalid_message_gracefully(self): # Given diff --git a/tests/dishReceiverTest.py b/tests/dishReceiverTest.py index 047e665..24579bd 100644 --- a/tests/dishReceiverTest.py +++ b/tests/dishReceiverTest.py @@ -1,4 +1,5 @@ import unittest +from itertools import chain, repeat from unittest import TestCase from unittest.mock import MagicMock from uuid import uuid4 @@ -94,7 +95,7 @@ def test_registers_handler(self): receiver.register(handler) # Then - self.assertIn(handler, receiver.get_handlers()) + self.assertIn(handler, receiver._handlers) def test_deregisters_handler(self): # Given @@ -107,7 +108,7 @@ def test_deregisters_handler(self): receiver.deregister(handler) # Then - self.assertNotIn(handler, receiver.get_handlers()) + self.assertNotIn(handler, receiver._handlers) def test_calls_registered_handler_on_message(self): # Given @@ -118,9 +119,7 @@ def test_calls_registered_handler_on_message(self): with DishReceiver(context) as receiver: receiver._poller = MagicMock(spec=Poller) - receiver._poller.poll.side_effect = [ - {context.socket.return_value: POLLIN}, - ] + receiver._poller.poll.side_effect = chain([{context.socket.return_value: POLLIN}], repeat({})) receiver.register(handler) # When @@ -138,9 +137,7 @@ def test_handles_message_receive_error_gracefully(self): with DishReceiver(context) as receiver: receiver._poller = MagicMock(spec=Poller) - receiver._poller.poll.side_effect = [ - {context.socket.return_value: POLLIN}, - ] + receiver._poller.poll.side_effect = chain([{context.socket.return_value: POLLIN}], repeat({})) receiver.register(handler) # When @@ -159,9 +156,7 @@ def test_handles_handler_execution_error_gracefully(self): with DishReceiver(context) as receiver: receiver._poller = MagicMock(spec=Poller) - receiver._poller.poll.side_effect = [ - {context.socket.return_value: POLLIN}, - ] + receiver._poller.poll.side_effect = chain([{context.socket.return_value: POLLIN}], repeat({})) receiver.register(handler) # When diff --git a/tests/scheduledDiscovererTest.py b/tests/scheduledDiscovererTest.py index 43df237..d247291 100644 --- a/tests/scheduledDiscovererTest.py +++ b/tests/scheduledDiscovererTest.py @@ -87,21 +87,6 @@ def test_deregisters_event_handler(self): # Then discoverer.deregister.assert_called_once_with(handler) - def test_returns_event_handlers(self): - # Given - discoverer = MagicMock(spec=Discoverer) - timer = MagicMock(spec=IReusableTimer) - scheduled_discoverer = ScheduledDiscoverer(discoverer, timer) - scheduled_discoverer.start(GROUP, SERVICE_QUERY) - handler = MagicMock(spec=OnDiscoveryEvent) - scheduled_discoverer.register(handler) - - # When - result = scheduled_discoverer.get_handlers() - - # Then - self.assertEqual(discoverer.get_handlers(), result) - def test_sends_service_query(self): # Given discoverer = MagicMock(spec=Discoverer)