Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions clients/aws-sdk-qbusiness/tests/integration/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0

from smithy_aws_core.identity import EnvironmentCredentialsResolver

from aws_sdk_qbusiness.client import QBusinessClient
from aws_sdk_qbusiness.config import Config

REGION = "us-east-1"


def create_qbusiness_client(region: str) -> QBusinessClient:
"""Helper to create a QBusinessClient for a given region."""
return QBusinessClient(
config=Config(
endpoint_uri=f"https://qbusiness.{region}.api.aws",
Comment thread
jonathan343 marked this conversation as resolved.
region=region,
aws_credentials_identity_resolver=EnvironmentCredentialsResolver(),
)
)
218 changes: 218 additions & 0 deletions clients/aws-sdk-qbusiness/tests/integration/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,218 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0

"""Pytest fixtures for Q Business integration tests.

Creates and tears down a Q Business application (with an index and retriever)
once per test session. The ``qbusiness_app`` fixture provides the application
ID.
"""

import asyncio
import uuid

import pytest

from aws_sdk_qbusiness.models import (
ApplicationStatus,
CreateApplicationInput,
CreateIndexInput,
CreateRetrieverInput,
DeleteApplicationInput,
GetApplicationInput,
GetIndexInput,
GetRetrieverInput,
IndexStatus,
NativeIndexConfiguration,
RetrieverConfigurationNativeIndexConfiguration,
RetrieverStatus,
RetrieverType,
Tag,
ThrottlingException,
)

from . import REGION, create_qbusiness_client

# Tags applied to all resources so orphaned resources from interrupted
# test runs can be discovered and cleaned up.
_TAGS = [Tag(key="Purpose", value="IntegTest")]

_POLL_INTERVAL_SECONDS = 10
_POLL_TIMEOUT_SECONDS = 300

# ThrottlingException is not marked retryable in the service model, so the SDK
# does not retry it automatically. Used for CreateApplication under concurrency.
_THROTTLE_RETRY_DELAY_SECONDS = 5
_THROTTLE_RETRY_TIMEOUT_SECONDS = 300


async def _wait_for_application_active(client, application_id: str) -> None:
"""Wait for an Application to reach ACTIVE status.

Args:
client: A Q Business client.
application_id: The Application ID to poll.
"""
deadline = asyncio.get_running_loop().time() + _POLL_TIMEOUT_SECONDS
while asyncio.get_running_loop().time() < deadline:
response = await client.get_application(
input=GetApplicationInput(application_id=application_id)
)
if response.status == ApplicationStatus.ACTIVE:
return
if response.status in {ApplicationStatus.FAILED, ApplicationStatus.DELETING}:
raise RuntimeError(
f"Application {application_id} entered terminal state {response.status}"
)
await asyncio.sleep(_POLL_INTERVAL_SECONDS)
raise TimeoutError(f"Application {application_id} did not become ACTIVE in time")
Comment thread
jonathan343 marked this conversation as resolved.


async def _wait_for_index_active(client, application_id: str, index_id: str) -> None:
"""Wait for an Index to reach ACTIVE status.

Args:
client: A Q Business client.
application_id: The parent Application ID.
index_id: The Index ID to poll.
"""
deadline = asyncio.get_running_loop().time() + _POLL_TIMEOUT_SECONDS
while asyncio.get_running_loop().time() < deadline:
response = await client.get_index(
input=GetIndexInput(application_id=application_id, index_id=index_id)
)
if response.status == IndexStatus.ACTIVE:
return
if response.status in {IndexStatus.FAILED, IndexStatus.DELETING}:
raise RuntimeError(
f"Index {index_id} entered terminal state {response.status}"
)
await asyncio.sleep(_POLL_INTERVAL_SECONDS)
raise TimeoutError(f"Index {index_id} did not become ACTIVE in time")


async def _wait_for_retriever_active(
client, application_id: str, retriever_id: str
) -> None:
"""Wait for a Retriever to reach ACTIVE status.

Args:
client: A Q Business client.
application_id: The parent Application ID.
retriever_id: The Retriever ID to poll.
"""
deadline = asyncio.get_running_loop().time() + _POLL_TIMEOUT_SECONDS
while asyncio.get_running_loop().time() < deadline:
response = await client.get_retriever(
input=GetRetrieverInput(
application_id=application_id, retriever_id=retriever_id
)
)
if response.status == RetrieverStatus.ACTIVE:
return
if response.status == RetrieverStatus.FAILED:
raise RuntimeError(
f"Retriever {retriever_id} entered terminal state {response.status}"
)
await asyncio.sleep(_POLL_INTERVAL_SECONDS)
raise TimeoutError(f"Retriever {retriever_id} did not become ACTIVE in time")


async def _create_qbusiness_app(
client, app_name: str, index_name: str, retriever_name: str
) -> str:
"""Create a Q Business application with index and retriever.

Args:
client: A Q Business client.
app_name: The display name of the Application to create.
index_name: The display name of the Index to create.
retriever_name: The display name of the Retriever to create.

Returns:
The application ID.
"""
# ThrottlingException is not marked retryable in the service model, so the
# SDK does not retry it automatically. Retry here for concurrent test runs.
deadline = asyncio.get_running_loop().time() + _THROTTLE_RETRY_TIMEOUT_SECONDS
while True:
try:
response = await client.create_application(
input=CreateApplicationInput(
display_name=app_name,
identity_type="ANONYMOUS",
tags=_TAGS,
client_token=str(uuid.uuid4()),
)
)
break
except ThrottlingException:
if asyncio.get_running_loop().time() >= deadline:
raise
await asyncio.sleep(_THROTTLE_RETRY_DELAY_SECONDS)
application_id = response.application_id
assert application_id is not None
await _wait_for_application_active(client, application_id)

response = await client.create_index(
input=CreateIndexInput(
application_id=application_id,
display_name=index_name,
tags=_TAGS,
client_token=str(uuid.uuid4()),
)
)
index_id = response.index_id
assert index_id is not None
await _wait_for_index_active(client, application_id, index_id)

response = await client.create_retriever(
input=CreateRetrieverInput(
application_id=application_id,
display_name=retriever_name,
type=RetrieverType.NATIVE_INDEX,
configuration=RetrieverConfigurationNativeIndexConfiguration(
value=NativeIndexConfiguration(index_id=index_id)
),
tags=_TAGS,
client_token=str(uuid.uuid4()),
)
)
retriever_id = response.retriever_id
assert retriever_id is not None
await _wait_for_retriever_active(client, application_id, retriever_id)

return application_id


async def _delete_qbusiness_app(client, application_id: str | None) -> None:
"""Delete a Q Business application. Cascades to its index and retriever.

Args:
client: A Q Business client.
application_id: The Application ID to delete, or None if creation failed.
"""
if not application_id:
return
await client.delete_application(
input=DeleteApplicationInput(application_id=application_id)
)


@pytest.fixture(scope="session")
async def qbusiness_app():
"""Create a Q Business application for the test session and delete it after."""
unique_suffix = uuid.uuid4().hex[:16]
app_name = f"integ-test-qbusiness-app-{unique_suffix}"
index_name = f"integ-test-qbusiness-index-{unique_suffix}"
retriever_name = f"integ-test-qbusiness-retriever-{unique_suffix}"

client = create_qbusiness_client(REGION)
application_id: str | None = None
try:
application_id = await _create_qbusiness_app(
client, app_name, index_name, retriever_name
)
yield application_id
finally:
await _delete_qbusiness_app(client, application_id)
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0

"""Test bidirectional event stream handling for the Chat API."""

import asyncio
import uuid

from smithy_core.aio.eventstream import DuplexEventStream

from aws_sdk_qbusiness.models import (
ChatInput,
ChatInputStream,
ChatInputStreamConfigurationEvent,
ChatInputStreamEndOfInputEvent,
ChatInputStreamTextEvent,
ChatOutput,
ChatOutputStream,
ChatOutputStreamMetadataEvent,
ChatOutputStreamTextEvent,
ChatOutputStreamUnknown,
ConfigurationEvent,
EndOfInputEvent,
TextInputEvent,
)

from . import REGION, create_qbusiness_client


async def _send_chat_events(
stream: DuplexEventStream[ChatInputStream, ChatOutputStream, ChatOutput],
) -> None:
"""Send chat input events: configuration, text message, end of input."""
await stream.input_stream.send(
ChatInputStreamConfigurationEvent(
value=ConfigurationEvent(chat_mode="RETRIEVAL_MODE")
)
)

await stream.input_stream.send(
ChatInputStreamTextEvent(value=TextInputEvent(user_message="Hello"))
)

await stream.input_stream.send(
ChatInputStreamEndOfInputEvent(value=EndOfInputEvent())
)

await stream.input_stream.close()


async def _receive_chat_output(
stream: DuplexEventStream[ChatInputStream, ChatOutputStream, ChatOutput],
) -> tuple[bool, bool]:
"""Receive and validate chat output from the stream.

Returns:
Tuple of (got_text_events, got_metadata_event)
"""
got_text_events = False
got_metadata_event = False

_, output_stream = await stream.await_output()
if output_stream is None:
return got_text_events, got_metadata_event

async for event in output_stream:
if isinstance(event, ChatOutputStreamTextEvent):
got_text_events = True
assert event.value.system_message_type == "RESPONSE"
assert event.value.conversation_id is not None
assert event.value.user_message_id is not None
assert event.value.system_message_id is not None
assert event.value.system_message is not None
assert isinstance(event.value.system_message, str)
assert len(event.value.system_message) > 0
elif isinstance(event, ChatOutputStreamMetadataEvent):
got_metadata_event = True
assert event.value.conversation_id is not None
assert event.value.user_message_id is not None
assert event.value.system_message_id is not None
assert event.value.source_attributions is not None
assert event.value.final_text_message is not None
assert isinstance(event.value.final_text_message, str)
assert len(event.value.final_text_message) > 0
elif isinstance(event, ChatOutputStreamUnknown):
pass
else:
raise RuntimeError(
f"Received unexpected event type in stream: {type(event).__name__}"
)

return got_text_events, got_metadata_event


async def test_chat_bidirectional_streaming(qbusiness_app: str) -> None:
"""Test bidirectional streaming with text input and chat output."""
qbusiness_client = create_qbusiness_client(REGION)

stream = await qbusiness_client.chat(
input=ChatInput(application_id=qbusiness_app, client_token=str(uuid.uuid4()))
)

results = await asyncio.gather(
_send_chat_events(stream), _receive_chat_output(stream)
)
got_text_events, got_metadata_event = results[1]

assert got_text_events, "Expected to receive text output events"
assert got_metadata_event, "Expected to receive a metadata event"
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0

"""Test non-streaming output type handling."""

import uuid

from aws_sdk_qbusiness.models import ChatSyncInput, ChatSyncOutput

from . import REGION, create_qbusiness_client


async def test_chat_sync(qbusiness_app: str) -> None:
"""Test non-streaming ChatSync operation."""
qbusiness_client = create_qbusiness_client(REGION)

response = await qbusiness_client.chat_sync(
input=ChatSyncInput(
application_id=qbusiness_app,
user_message="Hello",
client_token=str(uuid.uuid4()),
)
)

assert isinstance(response, ChatSyncOutput)
assert response.conversation_id is not None
assert response.system_message is not None
assert isinstance(response.system_message, str)
assert len(response.system_message) > 0
assert response.system_message_id is not None
assert response.user_message_id is not None
assert response.source_attributions is not None
assert response.failed_attachments is not None