Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 17 additions & 5 deletions docs/reference/openapi.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,14 @@ components:
- type
title: TaskResult
type: object
TaskStatusEnum:
enum:
- PENDING
- COMPLETE
- ERROR
- RUNNING
title: TaskStatusEnum
type: string
TasksListResponse:
additionalProperties: false
description: Diagnostic information on the tasks
Expand Down Expand Up @@ -441,7 +449,7 @@ info:
name: Apache 2.0
url: https://www.apache.org/licenses/LICENSE-2.0.html
title: BlueAPI Control
version: 1.4.0
version: 1.4.1
openapi: 3.1.0
paths:
/api/v1/devices:
Expand Down Expand Up @@ -599,15 +607,17 @@ paths:
get:
description: 'Retrieve tasks based on their status.

The status of a newly created task is ''unstarted''.'
The status of a newly created task is PENDING.'
operationId: get_tasks_api_v1_tasks_get
parameters:
- in: query
name: task_status
required: false
schema:
anyOf:
- $ref: '#/components/schemas/TaskStatusEnum'
- type: 'null'
title: Task Status
type: string
responses:
'200':
content:
Expand Down Expand Up @@ -997,15 +1007,17 @@ paths:
deprecated: true
description: 'Retrieve tasks based on their status.

The status of a newly created task is ''unstarted''.'
The status of a newly created task is PENDING.'
operationId: get_tasks_tasks_get
parameters:
- in: query
name: task_status
required: false
schema:
anyOf:
- $ref: '#/components/schemas/TaskStatusEnum'
- type: 'null'
title: Task Status
type: string
responses:
'200':
content:
Expand Down
2 changes: 1 addition & 1 deletion src/blueapi/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,7 @@ class ApplicationConfig(BlueapiBaseModel):
"""

#: API version to publish in OpenAPI schema
REST_API_VERSION: ClassVar[str] = "1.4.0"
REST_API_VERSION: ClassVar[str] = "1.4.1"

LICENSE_INFO: ClassVar[dict[str, str]] = {
"name": "Apache 2.0",
Expand Down
13 changes: 4 additions & 9 deletions src/blueapi/service/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,11 +233,6 @@ def remove_callback_when_task_finished(
return task


def get_tasks_by_status(status: TaskStatusEnum) -> list[TrackableTask]:
"""Retrieve a list of tasks based on their status."""
return worker().get_tasks_by_status(status)


def get_active_task() -> TrackableTask | None:
"""Task the worker is currently running"""
return worker().get_active_task()
Expand All @@ -264,10 +259,10 @@ def cancel_active_task(failure: bool, reason: str | None) -> str:
return worker().cancel_active_task(failure, reason)


def get_tasks() -> list[TrackableTask]:
"""Return a list of all tasks on the worker,
any one of which can be triggered with begin_task"""
return worker().get_tasks()
def get_tasks_by_status(status: TaskStatusEnum | None = None) -> list[TrackableTask]:
"""Retrieve a list of tasks based on their status.
Return a list of all tasks on the worker if status is None"""
return worker().get_tasks_by_status(status)


def get_task_by_id(task_id: str) -> TrackableTask | None:
Expand Down
27 changes: 3 additions & 24 deletions src/blueapi/service/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor
from opentelemetry.trace import get_tracer_provider
from pydantic import ValidationError
from pydantic.json_schema import SkipJsonSchema
from starlette.responses import JSONResponse
from super_state_machine.errors import TransitionError

Expand Down Expand Up @@ -316,38 +315,18 @@ def delete_submitted_task(
return TaskResponse(task_id=runner.run(interface.clear_task, task_id))


@start_as_current_span(TRACER, "v")
def validate_task_status(v: str) -> TaskStatusEnum:
v_upper = v.upper()
if v_upper not in TaskStatusEnum.__members__:
raise ValueError("Invalid status query parameter")
return TaskStatusEnum(v_upper)


@secure_router_v1.get("/tasks", status_code=status.HTTP_200_OK, tags=[Tag.TASK])
@secure_router.get("/tasks", status_code=status.HTTP_200_OK, tags=[Tag.TASK])
@start_as_current_span(TRACER)
def get_tasks(
runner: Annotated[WorkerDispatcher, Depends(_runner)],
task_status: str | SkipJsonSchema[None] = None,
task_status: TaskStatusEnum | None = None,
) -> TasksListResponse:
"""
Retrieve tasks based on their status.
The status of a newly created task is 'unstarted'.
The status of a newly created task is PENDING.
"""
if task_status:
add_span_attributes({"status": task_status})
try:
desired_status = validate_task_status(task_status)
except ValueError as e:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Invalid status query parameter",
) from e

tasks = runner.run(interface.get_tasks_by_status, desired_status)
else:
tasks = runner.run(interface.get_tasks)
tasks = runner.run(interface.get_tasks_by_status, task_status)
return TasksListResponse(tasks=tasks)


Expand Down
23 changes: 10 additions & 13 deletions src/blueapi/worker/task_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,16 +194,6 @@ def cancel_active_task(
add_span_attributes({"Task stopped": reason or default_reason})
return self._current.task_id

@start_as_current_span(TRACER)
def get_tasks(self) -> list[TrackableTask]:
"""
Return a list of all tasks on the worker,
any one of which can be triggered with begin_task.
Returns:
List[TrackableTask[T]]: List of task objects
"""
return list(self._pending_tasks.values()) + list(self._completed_tasks.values())

@start_as_current_span(TRACER, "task_id")
def get_task_by_id(self, task_id: str) -> TrackableTask | None:
"""
Expand All @@ -217,12 +207,15 @@ def get_task_by_id(self, task_id: str) -> TrackableTask | None:
"""
return self._pending_tasks.get(task_id, None) or self._completed_tasks[task_id]

@start_as_current_span(TRACER, "status")
def get_tasks_by_status(self, status: TaskStatusEnum) -> list[TrackableTask]:
@start_as_current_span(TRACER)
def get_tasks_by_status(
self, status: TaskStatusEnum | None = None
) -> list[TrackableTask]:
"""
Retrieve a list of tasks based on their status.
Args:
status TaskStatusEnum: The status to filter tasks by.
status Optional[TaskStatusEnum]: The status to filter tasks by.
If status is None return all tasks.
Returns:
list[TrackableTask]: A list of tasks that match the given status.
"""
Expand All @@ -236,6 +229,10 @@ def get_tasks_by_status(self, status: TaskStatusEnum) -> list[TrackableTask]:
return [task for task in self._pending_tasks.values() if task.is_pending]
elif status == TaskStatusEnum.COMPLETE:
return list(self._completed_tasks.values())
elif status is None:
return list(self._pending_tasks.values()) + list(
self._completed_tasks.values()
)
return []

@start_as_current_span(TRACER)
Expand Down
40 changes: 16 additions & 24 deletions tests/unit_tests/service/test_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,26 +264,30 @@ def test_subscribers_removed_when_task_not_found(

@patch("blueapi.service.interface.TaskWorker.get_tasks_by_status")
def test_get_tasks_by_status(get_tasks_by_status_mock: MagicMock):
pending_task1 = TrackableTask(task_id="0", task=Task(name="pending_task1"))
pending_task2 = TrackableTask(task_id="1", task=Task(name="pending_task2"))
running_task = TrackableTask(task_id="2", task=Task(name="running_task"))
running_task = [TrackableTask(task_id="2", task=Task(name="running_task"))]
pending_task = [
TrackableTask(task_id="0", task=Task(name="pending_task1")),
TrackableTask(task_id="1", task=Task(name="pending_task2")),
]

def mock_tasks_by_status(status: TaskStatusEnum) -> list[TrackableTask]:
def mock_tasks_by_status(
status: TaskStatusEnum | None = None,
) -> list[TrackableTask]:
if status == TaskStatusEnum.PENDING:
return [pending_task1, pending_task2]
return pending_task
elif status == TaskStatusEnum.RUNNING:
return [running_task]
else:
return running_task
elif status == TaskStatusEnum.COMPLETE:
return []
else:
return pending_task + running_task

get_tasks_by_status_mock.side_effect = mock_tasks_by_status

assert interface.get_tasks_by_status(TaskStatusEnum.PENDING) == [
pending_task1,
pending_task2,
]
assert interface.get_tasks_by_status(TaskStatusEnum.RUNNING) == [running_task]
assert interface.get_tasks_by_status(TaskStatusEnum.PENDING) == pending_task
assert interface.get_tasks_by_status(TaskStatusEnum.RUNNING) == running_task
assert interface.get_tasks_by_status(TaskStatusEnum.COMPLETE) == []
assert interface.get_tasks_by_status() == pending_task + running_task


@patch("blueapi.service.interface.BlueskyContext.numtracker")
Expand Down Expand Up @@ -334,18 +338,6 @@ def test_cancel_active_task(cancel_active_task_mock: MagicMock):
cancel_active_task_mock.assert_called_once_with(fail, reason)


@patch("blueapi.service.interface.TaskWorker.get_tasks")
def test_get_tasks(get_tasks_mock: MagicMock):
tasks = [
TrackableTask(task_id="0", task=Task(name="0")),
TrackableTask(task_id="1", task=Task(name="1")),
TrackableTask(task_id="2", task=Task(name="2")),
]
get_tasks_mock.return_value = tasks

assert interface.get_tasks() == tasks


@pytest.mark.parametrize("tiled_enabled", [True, False])
@patch("blueapi.service.interface.context")
@patch("blueapi.service.interface.config")
Expand Down
2 changes: 1 addition & 1 deletion tests/unit_tests/service/test_rest_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,7 +413,7 @@ def test_get_tasks_by_status(mock_runner: Mock, client: TestClient) -> None:

def test_get_tasks_by_status_invalid(client: TestClient) -> None:
response = client.get("/tasks", params={"task_status": "AN_INVALID_STATUS"})
assert response.status_code == status.HTTP_400_BAD_REQUEST
assert response.status_code == status.HTTP_422_UNPROCESSABLE_CONTENT


def test_delete_submitted_task(mock_runner: Mock, client: TestClient) -> None:
Expand Down
24 changes: 12 additions & 12 deletions tests/unit_tests/worker/test_task_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,25 +181,25 @@ def test_multi_start(inert_worker: TaskWorker) -> None:
def test_submit_task(
worker: TaskWorker,
) -> None:
assert worker.get_tasks() == []
assert worker.get_tasks_by_status() == []
task_id = worker.submit_task(_SIMPLE_TASK)
assert worker.get_tasks() == [
assert worker.get_tasks_by_status() == [
TrackableTask.model_construct(
task_id=task_id, request_id=ANY, task=_SIMPLE_TASK
)
]


def test_submit_multiple_tasks(worker: TaskWorker) -> None:
assert worker.get_tasks() == []
assert worker.get_tasks_by_status() == []
task_id_1 = worker.submit_task(_SIMPLE_TASK)
assert worker.get_tasks() == [
assert worker.get_tasks_by_status() == [
TrackableTask.model_construct(
task_id=task_id_1, request_id=ANY, task=_SIMPLE_TASK
)
]
task_id_2 = worker.submit_task(_LONG_TASK)
assert worker.get_tasks() == [
assert worker.get_tasks_by_status() == [
TrackableTask.model_construct(
task_id=task_id_1, request_id=ANY, task=_SIMPLE_TASK
),
Expand All @@ -217,14 +217,14 @@ def test_stop_with_task_pending(inert_worker: TaskWorker) -> None:

def test_restart_leaves_task_pending(worker: TaskWorker) -> None:
task_id = worker.submit_task(_SIMPLE_TASK)
assert worker.get_tasks() == [
assert worker.get_tasks_by_status() == [
TrackableTask.model_construct(
task_id=task_id, request_id=ANY, task=_SIMPLE_TASK
)
]
worker.stop()
worker.start()
assert worker.get_tasks() == [
assert worker.get_tasks_by_status() == [
TrackableTask.model_construct(
task_id=task_id, request_id=ANY, task=_SIMPLE_TASK
)
Expand All @@ -234,13 +234,13 @@ def test_restart_leaves_task_pending(worker: TaskWorker) -> None:
def test_submit_before_start_pending(inert_worker: TaskWorker) -> None:
task_id = inert_worker.submit_task(_SIMPLE_TASK)
inert_worker.start()
assert inert_worker.get_tasks() == [
assert inert_worker.get_tasks_by_status() == [
TrackableTask.model_construct(
task_id=task_id, request_id=ANY, task=_SIMPLE_TASK
)
]
inert_worker.stop()
assert inert_worker.get_tasks() == [
assert inert_worker.get_tasks_by_status() == [
TrackableTask.model_construct(
task_id=task_id, request_id=ANY, task=_SIMPLE_TASK
)
Expand All @@ -249,13 +249,13 @@ def test_submit_before_start_pending(inert_worker: TaskWorker) -> None:

def test_clear_task(worker: TaskWorker) -> None:
task_id = worker.submit_task(_SIMPLE_TASK)
assert worker.get_tasks() == [
assert worker.get_tasks_by_status() == [
TrackableTask.model_construct(
task_id=task_id, request_id=ANY, task=_SIMPLE_TASK
)
]
assert worker.clear_task(task_id)
assert worker.get_tasks() == []
assert worker.get_tasks_by_status() == []


def test_clear_nonexistent_task(worker: TaskWorker) -> None:
Expand Down Expand Up @@ -693,7 +693,7 @@ def test_submit_task_span_ok(
exporter: JsonObjectSpanExporter,
worker: TaskWorker,
) -> None:
assert worker.get_tasks() == []
assert worker.get_tasks_by_status() == []
with asserting_span_exporter(exporter, "submit_task", "task.name", "task.params"):
worker.submit_task(_SIMPLE_TASK)

Expand Down
Loading