diff --git a/clients/aws-sdk-qbusiness/tests/integration/__init__.py b/clients/aws-sdk-qbusiness/tests/integration/__init__.py new file mode 100644 index 0000000..4f44e2c --- /dev/null +++ b/clients/aws-sdk-qbusiness/tests/integration/__init__.py @@ -0,0 +1,20 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +from smithy_aws_core.identity import EnvironmentCredentialsResolver + +from aws_sdk_qbusiness.client import QBusinessClient +from aws_sdk_qbusiness.config import Config + +REGION = "us-east-1" + + +def create_qbusiness_client(region: str) -> QBusinessClient: + """Helper to create a QBusinessClient for a given region.""" + return QBusinessClient( + config=Config( + endpoint_uri=f"https://qbusiness.{region}.api.aws", + region=region, + aws_credentials_identity_resolver=EnvironmentCredentialsResolver(), + ) + ) diff --git a/clients/aws-sdk-qbusiness/tests/integration/conftest.py b/clients/aws-sdk-qbusiness/tests/integration/conftest.py new file mode 100644 index 0000000..c8d7b44 --- /dev/null +++ b/clients/aws-sdk-qbusiness/tests/integration/conftest.py @@ -0,0 +1,218 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Pytest fixtures for Q Business integration tests. + +Creates and tears down a Q Business application (with an index and retriever) +once per test session. The ``qbusiness_app`` fixture provides the application +ID. +""" + +import asyncio +import uuid + +import pytest + +from aws_sdk_qbusiness.models import ( + ApplicationStatus, + CreateApplicationInput, + CreateIndexInput, + CreateRetrieverInput, + DeleteApplicationInput, + GetApplicationInput, + GetIndexInput, + GetRetrieverInput, + IndexStatus, + NativeIndexConfiguration, + RetrieverConfigurationNativeIndexConfiguration, + RetrieverStatus, + RetrieverType, + Tag, + ThrottlingException, +) + +from . import REGION, create_qbusiness_client + +# Tags applied to all resources so orphaned resources from interrupted +# test runs can be discovered and cleaned up. +_TAGS = [Tag(key="Purpose", value="IntegTest")] + +_POLL_INTERVAL_SECONDS = 10 +_POLL_TIMEOUT_SECONDS = 300 + +# ThrottlingException is not marked retryable in the service model, so the SDK +# does not retry it automatically. Used for CreateApplication under concurrency. +_THROTTLE_RETRY_DELAY_SECONDS = 5 +_THROTTLE_RETRY_TIMEOUT_SECONDS = 300 + + +async def _wait_for_application_active(client, application_id: str) -> None: + """Wait for an Application to reach ACTIVE status. + + Args: + client: A Q Business client. + application_id: The Application ID to poll. + """ + deadline = asyncio.get_running_loop().time() + _POLL_TIMEOUT_SECONDS + while asyncio.get_running_loop().time() < deadline: + response = await client.get_application( + input=GetApplicationInput(application_id=application_id) + ) + if response.status == ApplicationStatus.ACTIVE: + return + if response.status in {ApplicationStatus.FAILED, ApplicationStatus.DELETING}: + raise RuntimeError( + f"Application {application_id} entered terminal state {response.status}" + ) + await asyncio.sleep(_POLL_INTERVAL_SECONDS) + raise TimeoutError(f"Application {application_id} did not become ACTIVE in time") + + +async def _wait_for_index_active(client, application_id: str, index_id: str) -> None: + """Wait for an Index to reach ACTIVE status. + + Args: + client: A Q Business client. + application_id: The parent Application ID. + index_id: The Index ID to poll. + """ + deadline = asyncio.get_running_loop().time() + _POLL_TIMEOUT_SECONDS + while asyncio.get_running_loop().time() < deadline: + response = await client.get_index( + input=GetIndexInput(application_id=application_id, index_id=index_id) + ) + if response.status == IndexStatus.ACTIVE: + return + if response.status in {IndexStatus.FAILED, IndexStatus.DELETING}: + raise RuntimeError( + f"Index {index_id} entered terminal state {response.status}" + ) + await asyncio.sleep(_POLL_INTERVAL_SECONDS) + raise TimeoutError(f"Index {index_id} did not become ACTIVE in time") + + +async def _wait_for_retriever_active( + client, application_id: str, retriever_id: str +) -> None: + """Wait for a Retriever to reach ACTIVE status. + + Args: + client: A Q Business client. + application_id: The parent Application ID. + retriever_id: The Retriever ID to poll. + """ + deadline = asyncio.get_running_loop().time() + _POLL_TIMEOUT_SECONDS + while asyncio.get_running_loop().time() < deadline: + response = await client.get_retriever( + input=GetRetrieverInput( + application_id=application_id, retriever_id=retriever_id + ) + ) + if response.status == RetrieverStatus.ACTIVE: + return + if response.status == RetrieverStatus.FAILED: + raise RuntimeError( + f"Retriever {retriever_id} entered terminal state {response.status}" + ) + await asyncio.sleep(_POLL_INTERVAL_SECONDS) + raise TimeoutError(f"Retriever {retriever_id} did not become ACTIVE in time") + + +async def _create_qbusiness_app( + client, app_name: str, index_name: str, retriever_name: str +) -> str: + """Create a Q Business application with index and retriever. + + Args: + client: A Q Business client. + app_name: The display name of the Application to create. + index_name: The display name of the Index to create. + retriever_name: The display name of the Retriever to create. + + Returns: + The application ID. + """ + # ThrottlingException is not marked retryable in the service model, so the + # SDK does not retry it automatically. Retry here for concurrent test runs. + deadline = asyncio.get_running_loop().time() + _THROTTLE_RETRY_TIMEOUT_SECONDS + while True: + try: + response = await client.create_application( + input=CreateApplicationInput( + display_name=app_name, + identity_type="ANONYMOUS", + tags=_TAGS, + client_token=str(uuid.uuid4()), + ) + ) + break + except ThrottlingException: + if asyncio.get_running_loop().time() >= deadline: + raise + await asyncio.sleep(_THROTTLE_RETRY_DELAY_SECONDS) + application_id = response.application_id + assert application_id is not None + await _wait_for_application_active(client, application_id) + + response = await client.create_index( + input=CreateIndexInput( + application_id=application_id, + display_name=index_name, + tags=_TAGS, + client_token=str(uuid.uuid4()), + ) + ) + index_id = response.index_id + assert index_id is not None + await _wait_for_index_active(client, application_id, index_id) + + response = await client.create_retriever( + input=CreateRetrieverInput( + application_id=application_id, + display_name=retriever_name, + type=RetrieverType.NATIVE_INDEX, + configuration=RetrieverConfigurationNativeIndexConfiguration( + value=NativeIndexConfiguration(index_id=index_id) + ), + tags=_TAGS, + client_token=str(uuid.uuid4()), + ) + ) + retriever_id = response.retriever_id + assert retriever_id is not None + await _wait_for_retriever_active(client, application_id, retriever_id) + + return application_id + + +async def _delete_qbusiness_app(client, application_id: str | None) -> None: + """Delete a Q Business application. Cascades to its index and retriever. + + Args: + client: A Q Business client. + application_id: The Application ID to delete, or None if creation failed. + """ + if not application_id: + return + await client.delete_application( + input=DeleteApplicationInput(application_id=application_id) + ) + + +@pytest.fixture(scope="session") +async def qbusiness_app(): + """Create a Q Business application for the test session and delete it after.""" + unique_suffix = uuid.uuid4().hex[:16] + app_name = f"integ-test-qbusiness-app-{unique_suffix}" + index_name = f"integ-test-qbusiness-index-{unique_suffix}" + retriever_name = f"integ-test-qbusiness-retriever-{unique_suffix}" + + client = create_qbusiness_client(REGION) + application_id: str | None = None + try: + application_id = await _create_qbusiness_app( + client, app_name, index_name, retriever_name + ) + yield application_id + finally: + await _delete_qbusiness_app(client, application_id) diff --git a/clients/aws-sdk-qbusiness/tests/integration/test_bidirectional_streaming.py b/clients/aws-sdk-qbusiness/tests/integration/test_bidirectional_streaming.py new file mode 100644 index 0000000..3da172e --- /dev/null +++ b/clients/aws-sdk-qbusiness/tests/integration/test_bidirectional_streaming.py @@ -0,0 +1,109 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Test bidirectional event stream handling for the Chat API.""" + +import asyncio +import uuid + +from smithy_core.aio.eventstream import DuplexEventStream + +from aws_sdk_qbusiness.models import ( + ChatInput, + ChatInputStream, + ChatInputStreamConfigurationEvent, + ChatInputStreamEndOfInputEvent, + ChatInputStreamTextEvent, + ChatOutput, + ChatOutputStream, + ChatOutputStreamMetadataEvent, + ChatOutputStreamTextEvent, + ChatOutputStreamUnknown, + ConfigurationEvent, + EndOfInputEvent, + TextInputEvent, +) + +from . import REGION, create_qbusiness_client + + +async def _send_chat_events( + stream: DuplexEventStream[ChatInputStream, ChatOutputStream, ChatOutput], +) -> None: + """Send chat input events: configuration, text message, end of input.""" + await stream.input_stream.send( + ChatInputStreamConfigurationEvent( + value=ConfigurationEvent(chat_mode="RETRIEVAL_MODE") + ) + ) + + await stream.input_stream.send( + ChatInputStreamTextEvent(value=TextInputEvent(user_message="Hello")) + ) + + await stream.input_stream.send( + ChatInputStreamEndOfInputEvent(value=EndOfInputEvent()) + ) + + await stream.input_stream.close() + + +async def _receive_chat_output( + stream: DuplexEventStream[ChatInputStream, ChatOutputStream, ChatOutput], +) -> tuple[bool, bool]: + """Receive and validate chat output from the stream. + + Returns: + Tuple of (got_text_events, got_metadata_event) + """ + got_text_events = False + got_metadata_event = False + + _, output_stream = await stream.await_output() + if output_stream is None: + return got_text_events, got_metadata_event + + async for event in output_stream: + if isinstance(event, ChatOutputStreamTextEvent): + got_text_events = True + assert event.value.system_message_type == "RESPONSE" + assert event.value.conversation_id is not None + assert event.value.user_message_id is not None + assert event.value.system_message_id is not None + assert event.value.system_message is not None + assert isinstance(event.value.system_message, str) + assert len(event.value.system_message) > 0 + elif isinstance(event, ChatOutputStreamMetadataEvent): + got_metadata_event = True + assert event.value.conversation_id is not None + assert event.value.user_message_id is not None + assert event.value.system_message_id is not None + assert event.value.source_attributions is not None + assert event.value.final_text_message is not None + assert isinstance(event.value.final_text_message, str) + assert len(event.value.final_text_message) > 0 + elif isinstance(event, ChatOutputStreamUnknown): + pass + else: + raise RuntimeError( + f"Received unexpected event type in stream: {type(event).__name__}" + ) + + return got_text_events, got_metadata_event + + +async def test_chat_bidirectional_streaming(qbusiness_app: str) -> None: + """Test bidirectional streaming with text input and chat output.""" + qbusiness_client = create_qbusiness_client(REGION) + + stream = await qbusiness_client.chat( + input=ChatInput(application_id=qbusiness_app, client_token=str(uuid.uuid4())) + ) + + results = await asyncio.gather( + _send_chat_events(stream), _receive_chat_output(stream) + ) + got_text_events, got_metadata_event = results[1] + + assert got_text_events, "Expected to receive text output events" + assert got_metadata_event, "Expected to receive a metadata event" diff --git a/clients/aws-sdk-qbusiness/tests/integration/test_non_streaming.py b/clients/aws-sdk-qbusiness/tests/integration/test_non_streaming.py new file mode 100644 index 0000000..3fb429b --- /dev/null +++ b/clients/aws-sdk-qbusiness/tests/integration/test_non_streaming.py @@ -0,0 +1,33 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Test non-streaming output type handling.""" + +import uuid + +from aws_sdk_qbusiness.models import ChatSyncInput, ChatSyncOutput + +from . import REGION, create_qbusiness_client + + +async def test_chat_sync(qbusiness_app: str) -> None: + """Test non-streaming ChatSync operation.""" + qbusiness_client = create_qbusiness_client(REGION) + + response = await qbusiness_client.chat_sync( + input=ChatSyncInput( + application_id=qbusiness_app, + user_message="Hello", + client_token=str(uuid.uuid4()), + ) + ) + + assert isinstance(response, ChatSyncOutput) + assert response.conversation_id is not None + assert response.system_message is not None + assert isinstance(response.system_message, str) + assert len(response.system_message) > 0 + assert response.system_message_id is not None + assert response.user_message_id is not None + assert response.source_attributions is not None + assert response.failed_attachments is not None