diff --git a/docs/reference/openapi.yaml b/docs/reference/openapi.yaml index 94a1f1540..cf9a1c05b 100644 --- a/docs/reference/openapi.yaml +++ b/docs/reference/openapi.yaml @@ -333,6 +333,14 @@ components: - type title: TaskResult type: object + TaskStatusEnum: + enum: + - PENDING + - COMPLETE + - ERROR + - RUNNING + title: TaskStatusEnum + type: string TasksListResponse: additionalProperties: false description: Diagnostic information on the tasks @@ -441,7 +449,7 @@ info: name: Apache 2.0 url: https://www.apache.org/licenses/LICENSE-2.0.html title: BlueAPI Control - version: 1.4.0 + version: 1.4.1 openapi: 3.1.0 paths: /api/v1/devices: @@ -599,15 +607,17 @@ paths: get: description: 'Retrieve tasks based on their status. - The status of a newly created task is ''unstarted''.' + The status of a newly created task is PENDING.' operationId: get_tasks_api_v1_tasks_get parameters: - in: query name: task_status required: false schema: + anyOf: + - $ref: '#/components/schemas/TaskStatusEnum' + - type: 'null' title: Task Status - type: string responses: '200': content: @@ -997,15 +1007,17 @@ paths: deprecated: true description: 'Retrieve tasks based on their status. - The status of a newly created task is ''unstarted''.' + The status of a newly created task is PENDING.' operationId: get_tasks_tasks_get parameters: - in: query name: task_status required: false schema: + anyOf: + - $ref: '#/components/schemas/TaskStatusEnum' + - type: 'null' title: Task Status - type: string responses: '200': content: diff --git a/src/blueapi/config.py b/src/blueapi/config.py index 29cf242a9..585f19034 100644 --- a/src/blueapi/config.py +++ b/src/blueapi/config.py @@ -322,7 +322,7 @@ class ApplicationConfig(BlueapiBaseModel): """ #: API version to publish in OpenAPI schema - REST_API_VERSION: ClassVar[str] = "1.4.0" + REST_API_VERSION: ClassVar[str] = "1.4.1" LICENSE_INFO: ClassVar[dict[str, str]] = { "name": "Apache 2.0", diff --git a/src/blueapi/service/interface.py b/src/blueapi/service/interface.py index 335d00477..68bddb6c3 100644 --- a/src/blueapi/service/interface.py +++ b/src/blueapi/service/interface.py @@ -233,11 +233,6 @@ def remove_callback_when_task_finished( return task -def get_tasks_by_status(status: TaskStatusEnum) -> list[TrackableTask]: - """Retrieve a list of tasks based on their status.""" - return worker().get_tasks_by_status(status) - - def get_active_task() -> TrackableTask | None: """Task the worker is currently running""" return worker().get_active_task() @@ -264,10 +259,10 @@ def cancel_active_task(failure: bool, reason: str | None) -> str: return worker().cancel_active_task(failure, reason) -def get_tasks() -> list[TrackableTask]: - """Return a list of all tasks on the worker, - any one of which can be triggered with begin_task""" - return worker().get_tasks() +def get_tasks_by_status(status: TaskStatusEnum | None = None) -> list[TrackableTask]: + """Retrieve a list of tasks based on their status. + Return a list of all tasks on the worker if status is None""" + return worker().get_tasks_by_status(status) def get_task_by_id(task_id: str) -> TrackableTask | None: diff --git a/src/blueapi/service/main.py b/src/blueapi/service/main.py index 1296bec58..4424e47ff 100644 --- a/src/blueapi/service/main.py +++ b/src/blueapi/service/main.py @@ -27,7 +27,6 @@ from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor from opentelemetry.trace import get_tracer_provider from pydantic import ValidationError -from pydantic.json_schema import SkipJsonSchema from starlette.responses import JSONResponse from super_state_machine.errors import TransitionError @@ -316,38 +315,18 @@ def delete_submitted_task( return TaskResponse(task_id=runner.run(interface.clear_task, task_id)) -@start_as_current_span(TRACER, "v") -def validate_task_status(v: str) -> TaskStatusEnum: - v_upper = v.upper() - if v_upper not in TaskStatusEnum.__members__: - raise ValueError("Invalid status query parameter") - return TaskStatusEnum(v_upper) - - @secure_router_v1.get("/tasks", status_code=status.HTTP_200_OK, tags=[Tag.TASK]) @secure_router.get("/tasks", status_code=status.HTTP_200_OK, tags=[Tag.TASK]) @start_as_current_span(TRACER) def get_tasks( runner: Annotated[WorkerDispatcher, Depends(_runner)], - task_status: str | SkipJsonSchema[None] = None, + task_status: TaskStatusEnum | None = None, ) -> TasksListResponse: """ Retrieve tasks based on their status. - The status of a newly created task is 'unstarted'. + The status of a newly created task is PENDING. """ - if task_status: - add_span_attributes({"status": task_status}) - try: - desired_status = validate_task_status(task_status) - except ValueError as e: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="Invalid status query parameter", - ) from e - - tasks = runner.run(interface.get_tasks_by_status, desired_status) - else: - tasks = runner.run(interface.get_tasks) + tasks = runner.run(interface.get_tasks_by_status, task_status) return TasksListResponse(tasks=tasks) diff --git a/src/blueapi/worker/task_worker.py b/src/blueapi/worker/task_worker.py index caa39fe7a..71dbb30a7 100644 --- a/src/blueapi/worker/task_worker.py +++ b/src/blueapi/worker/task_worker.py @@ -194,16 +194,6 @@ def cancel_active_task( add_span_attributes({"Task stopped": reason or default_reason}) return self._current.task_id - @start_as_current_span(TRACER) - def get_tasks(self) -> list[TrackableTask]: - """ - Return a list of all tasks on the worker, - any one of which can be triggered with begin_task. - Returns: - List[TrackableTask[T]]: List of task objects - """ - return list(self._pending_tasks.values()) + list(self._completed_tasks.values()) - @start_as_current_span(TRACER, "task_id") def get_task_by_id(self, task_id: str) -> TrackableTask | None: """ @@ -217,12 +207,15 @@ def get_task_by_id(self, task_id: str) -> TrackableTask | None: """ return self._pending_tasks.get(task_id, None) or self._completed_tasks[task_id] - @start_as_current_span(TRACER, "status") - def get_tasks_by_status(self, status: TaskStatusEnum) -> list[TrackableTask]: + @start_as_current_span(TRACER) + def get_tasks_by_status( + self, status: TaskStatusEnum | None = None + ) -> list[TrackableTask]: """ Retrieve a list of tasks based on their status. Args: - status TaskStatusEnum: The status to filter tasks by. + status Optional[TaskStatusEnum]: The status to filter tasks by. + If status is None return all tasks. Returns: list[TrackableTask]: A list of tasks that match the given status. """ @@ -236,6 +229,10 @@ def get_tasks_by_status(self, status: TaskStatusEnum) -> list[TrackableTask]: return [task for task in self._pending_tasks.values() if task.is_pending] elif status == TaskStatusEnum.COMPLETE: return list(self._completed_tasks.values()) + elif status is None: + return list(self._pending_tasks.values()) + list( + self._completed_tasks.values() + ) return [] @start_as_current_span(TRACER) diff --git a/tests/unit_tests/service/test_interface.py b/tests/unit_tests/service/test_interface.py index 39115a9a2..a7cf8edf2 100644 --- a/tests/unit_tests/service/test_interface.py +++ b/tests/unit_tests/service/test_interface.py @@ -264,26 +264,30 @@ def test_subscribers_removed_when_task_not_found( @patch("blueapi.service.interface.TaskWorker.get_tasks_by_status") def test_get_tasks_by_status(get_tasks_by_status_mock: MagicMock): - pending_task1 = TrackableTask(task_id="0", task=Task(name="pending_task1")) - pending_task2 = TrackableTask(task_id="1", task=Task(name="pending_task2")) - running_task = TrackableTask(task_id="2", task=Task(name="running_task")) + running_task = [TrackableTask(task_id="2", task=Task(name="running_task"))] + pending_task = [ + TrackableTask(task_id="0", task=Task(name="pending_task1")), + TrackableTask(task_id="1", task=Task(name="pending_task2")), + ] - def mock_tasks_by_status(status: TaskStatusEnum) -> list[TrackableTask]: + def mock_tasks_by_status( + status: TaskStatusEnum | None = None, + ) -> list[TrackableTask]: if status == TaskStatusEnum.PENDING: - return [pending_task1, pending_task2] + return pending_task elif status == TaskStatusEnum.RUNNING: - return [running_task] - else: + return running_task + elif status == TaskStatusEnum.COMPLETE: return [] + else: + return pending_task + running_task get_tasks_by_status_mock.side_effect = mock_tasks_by_status - assert interface.get_tasks_by_status(TaskStatusEnum.PENDING) == [ - pending_task1, - pending_task2, - ] - assert interface.get_tasks_by_status(TaskStatusEnum.RUNNING) == [running_task] + assert interface.get_tasks_by_status(TaskStatusEnum.PENDING) == pending_task + assert interface.get_tasks_by_status(TaskStatusEnum.RUNNING) == running_task assert interface.get_tasks_by_status(TaskStatusEnum.COMPLETE) == [] + assert interface.get_tasks_by_status() == pending_task + running_task @patch("blueapi.service.interface.BlueskyContext.numtracker") @@ -334,18 +338,6 @@ def test_cancel_active_task(cancel_active_task_mock: MagicMock): cancel_active_task_mock.assert_called_once_with(fail, reason) -@patch("blueapi.service.interface.TaskWorker.get_tasks") -def test_get_tasks(get_tasks_mock: MagicMock): - tasks = [ - TrackableTask(task_id="0", task=Task(name="0")), - TrackableTask(task_id="1", task=Task(name="1")), - TrackableTask(task_id="2", task=Task(name="2")), - ] - get_tasks_mock.return_value = tasks - - assert interface.get_tasks() == tasks - - @pytest.mark.parametrize("tiled_enabled", [True, False]) @patch("blueapi.service.interface.context") @patch("blueapi.service.interface.config") diff --git a/tests/unit_tests/service/test_rest_api.py b/tests/unit_tests/service/test_rest_api.py index 1ddf2c6ca..91809ffb5 100644 --- a/tests/unit_tests/service/test_rest_api.py +++ b/tests/unit_tests/service/test_rest_api.py @@ -413,7 +413,7 @@ def test_get_tasks_by_status(mock_runner: Mock, client: TestClient) -> None: def test_get_tasks_by_status_invalid(client: TestClient) -> None: response = client.get("/tasks", params={"task_status": "AN_INVALID_STATUS"}) - assert response.status_code == status.HTTP_400_BAD_REQUEST + assert response.status_code == status.HTTP_422_UNPROCESSABLE_CONTENT def test_delete_submitted_task(mock_runner: Mock, client: TestClient) -> None: diff --git a/tests/unit_tests/worker/test_task_worker.py b/tests/unit_tests/worker/test_task_worker.py index f03ba7c5f..00bbf6316 100644 --- a/tests/unit_tests/worker/test_task_worker.py +++ b/tests/unit_tests/worker/test_task_worker.py @@ -181,9 +181,9 @@ def test_multi_start(inert_worker: TaskWorker) -> None: def test_submit_task( worker: TaskWorker, ) -> None: - assert worker.get_tasks() == [] + assert worker.get_tasks_by_status() == [] task_id = worker.submit_task(_SIMPLE_TASK) - assert worker.get_tasks() == [ + assert worker.get_tasks_by_status() == [ TrackableTask.model_construct( task_id=task_id, request_id=ANY, task=_SIMPLE_TASK ) @@ -191,15 +191,15 @@ def test_submit_task( def test_submit_multiple_tasks(worker: TaskWorker) -> None: - assert worker.get_tasks() == [] + assert worker.get_tasks_by_status() == [] task_id_1 = worker.submit_task(_SIMPLE_TASK) - assert worker.get_tasks() == [ + assert worker.get_tasks_by_status() == [ TrackableTask.model_construct( task_id=task_id_1, request_id=ANY, task=_SIMPLE_TASK ) ] task_id_2 = worker.submit_task(_LONG_TASK) - assert worker.get_tasks() == [ + assert worker.get_tasks_by_status() == [ TrackableTask.model_construct( task_id=task_id_1, request_id=ANY, task=_SIMPLE_TASK ), @@ -217,14 +217,14 @@ def test_stop_with_task_pending(inert_worker: TaskWorker) -> None: def test_restart_leaves_task_pending(worker: TaskWorker) -> None: task_id = worker.submit_task(_SIMPLE_TASK) - assert worker.get_tasks() == [ + assert worker.get_tasks_by_status() == [ TrackableTask.model_construct( task_id=task_id, request_id=ANY, task=_SIMPLE_TASK ) ] worker.stop() worker.start() - assert worker.get_tasks() == [ + assert worker.get_tasks_by_status() == [ TrackableTask.model_construct( task_id=task_id, request_id=ANY, task=_SIMPLE_TASK ) @@ -234,13 +234,13 @@ def test_restart_leaves_task_pending(worker: TaskWorker) -> None: def test_submit_before_start_pending(inert_worker: TaskWorker) -> None: task_id = inert_worker.submit_task(_SIMPLE_TASK) inert_worker.start() - assert inert_worker.get_tasks() == [ + assert inert_worker.get_tasks_by_status() == [ TrackableTask.model_construct( task_id=task_id, request_id=ANY, task=_SIMPLE_TASK ) ] inert_worker.stop() - assert inert_worker.get_tasks() == [ + assert inert_worker.get_tasks_by_status() == [ TrackableTask.model_construct( task_id=task_id, request_id=ANY, task=_SIMPLE_TASK ) @@ -249,13 +249,13 @@ def test_submit_before_start_pending(inert_worker: TaskWorker) -> None: def test_clear_task(worker: TaskWorker) -> None: task_id = worker.submit_task(_SIMPLE_TASK) - assert worker.get_tasks() == [ + assert worker.get_tasks_by_status() == [ TrackableTask.model_construct( task_id=task_id, request_id=ANY, task=_SIMPLE_TASK ) ] assert worker.clear_task(task_id) - assert worker.get_tasks() == [] + assert worker.get_tasks_by_status() == [] def test_clear_nonexistent_task(worker: TaskWorker) -> None: @@ -693,7 +693,7 @@ def test_submit_task_span_ok( exporter: JsonObjectSpanExporter, worker: TaskWorker, ) -> None: - assert worker.get_tasks() == [] + assert worker.get_tasks_by_status() == [] with asserting_span_exporter(exporter, "submit_task", "task.name", "task.params"): worker.submit_task(_SIMPLE_TASK)