-
Notifications
You must be signed in to change notification settings - Fork 16
Add generated Q Business client and integration tests #54
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
2 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,20 @@ | ||
| # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
|
|
||
| from smithy_aws_core.identity import EnvironmentCredentialsResolver | ||
|
|
||
| from aws_sdk_qbusiness.client import QBusinessClient | ||
| from aws_sdk_qbusiness.config import Config | ||
|
|
||
| REGION = "us-east-1" | ||
|
|
||
|
|
||
| def create_qbusiness_client(region: str) -> QBusinessClient: | ||
| """Helper to create a QBusinessClient for a given region.""" | ||
| return QBusinessClient( | ||
| config=Config( | ||
| endpoint_uri=f"https://qbusiness.{region}.api.aws", | ||
| region=region, | ||
| aws_credentials_identity_resolver=EnvironmentCredentialsResolver(), | ||
| ) | ||
| ) | ||
218 changes: 218 additions & 0 deletions
218
clients/aws-sdk-qbusiness/tests/integration/conftest.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,218 @@ | ||
| # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
|
|
||
| """Pytest fixtures for Q Business integration tests. | ||
|
|
||
| Creates and tears down a Q Business application (with an index and retriever) | ||
| once per test session. The ``qbusiness_app`` fixture provides the application | ||
| ID. | ||
| """ | ||
|
|
||
| import asyncio | ||
| import uuid | ||
|
|
||
| import pytest | ||
|
|
||
| from aws_sdk_qbusiness.models import ( | ||
| ApplicationStatus, | ||
| CreateApplicationInput, | ||
| CreateIndexInput, | ||
| CreateRetrieverInput, | ||
| DeleteApplicationInput, | ||
| GetApplicationInput, | ||
| GetIndexInput, | ||
| GetRetrieverInput, | ||
| IndexStatus, | ||
| NativeIndexConfiguration, | ||
| RetrieverConfigurationNativeIndexConfiguration, | ||
| RetrieverStatus, | ||
| RetrieverType, | ||
| Tag, | ||
| ThrottlingException, | ||
| ) | ||
|
|
||
| from . import REGION, create_qbusiness_client | ||
|
|
||
| # Tags applied to all resources so orphaned resources from interrupted | ||
| # test runs can be discovered and cleaned up. | ||
| _TAGS = [Tag(key="Purpose", value="IntegTest")] | ||
|
|
||
| _POLL_INTERVAL_SECONDS = 10 | ||
| _POLL_TIMEOUT_SECONDS = 300 | ||
|
|
||
| # ThrottlingException is not marked retryable in the service model, so the SDK | ||
| # does not retry it automatically. Used for CreateApplication under concurrency. | ||
| _THROTTLE_RETRY_DELAY_SECONDS = 5 | ||
| _THROTTLE_RETRY_TIMEOUT_SECONDS = 300 | ||
|
|
||
|
|
||
| async def _wait_for_application_active(client, application_id: str) -> None: | ||
| """Wait for an Application to reach ACTIVE status. | ||
|
|
||
| Args: | ||
| client: A Q Business client. | ||
| application_id: The Application ID to poll. | ||
| """ | ||
| deadline = asyncio.get_running_loop().time() + _POLL_TIMEOUT_SECONDS | ||
| while asyncio.get_running_loop().time() < deadline: | ||
| response = await client.get_application( | ||
| input=GetApplicationInput(application_id=application_id) | ||
| ) | ||
| if response.status == ApplicationStatus.ACTIVE: | ||
| return | ||
| if response.status in {ApplicationStatus.FAILED, ApplicationStatus.DELETING}: | ||
| raise RuntimeError( | ||
| f"Application {application_id} entered terminal state {response.status}" | ||
| ) | ||
| await asyncio.sleep(_POLL_INTERVAL_SECONDS) | ||
| raise TimeoutError(f"Application {application_id} did not become ACTIVE in time") | ||
|
jonathan343 marked this conversation as resolved.
|
||
|
|
||
|
|
||
| async def _wait_for_index_active(client, application_id: str, index_id: str) -> None: | ||
| """Wait for an Index to reach ACTIVE status. | ||
|
|
||
| Args: | ||
| client: A Q Business client. | ||
| application_id: The parent Application ID. | ||
| index_id: The Index ID to poll. | ||
| """ | ||
| deadline = asyncio.get_running_loop().time() + _POLL_TIMEOUT_SECONDS | ||
| while asyncio.get_running_loop().time() < deadline: | ||
| response = await client.get_index( | ||
| input=GetIndexInput(application_id=application_id, index_id=index_id) | ||
| ) | ||
| if response.status == IndexStatus.ACTIVE: | ||
| return | ||
| if response.status in {IndexStatus.FAILED, IndexStatus.DELETING}: | ||
| raise RuntimeError( | ||
| f"Index {index_id} entered terminal state {response.status}" | ||
| ) | ||
| await asyncio.sleep(_POLL_INTERVAL_SECONDS) | ||
| raise TimeoutError(f"Index {index_id} did not become ACTIVE in time") | ||
|
|
||
|
|
||
| async def _wait_for_retriever_active( | ||
| client, application_id: str, retriever_id: str | ||
| ) -> None: | ||
| """Wait for a Retriever to reach ACTIVE status. | ||
|
|
||
| Args: | ||
| client: A Q Business client. | ||
| application_id: The parent Application ID. | ||
| retriever_id: The Retriever ID to poll. | ||
| """ | ||
| deadline = asyncio.get_running_loop().time() + _POLL_TIMEOUT_SECONDS | ||
| while asyncio.get_running_loop().time() < deadline: | ||
| response = await client.get_retriever( | ||
| input=GetRetrieverInput( | ||
| application_id=application_id, retriever_id=retriever_id | ||
| ) | ||
| ) | ||
| if response.status == RetrieverStatus.ACTIVE: | ||
| return | ||
| if response.status == RetrieverStatus.FAILED: | ||
| raise RuntimeError( | ||
| f"Retriever {retriever_id} entered terminal state {response.status}" | ||
| ) | ||
| await asyncio.sleep(_POLL_INTERVAL_SECONDS) | ||
| raise TimeoutError(f"Retriever {retriever_id} did not become ACTIVE in time") | ||
|
|
||
|
|
||
| async def _create_qbusiness_app( | ||
| client, app_name: str, index_name: str, retriever_name: str | ||
| ) -> str: | ||
| """Create a Q Business application with index and retriever. | ||
|
|
||
| Args: | ||
| client: A Q Business client. | ||
| app_name: The display name of the Application to create. | ||
| index_name: The display name of the Index to create. | ||
| retriever_name: The display name of the Retriever to create. | ||
|
|
||
| Returns: | ||
| The application ID. | ||
| """ | ||
| # ThrottlingException is not marked retryable in the service model, so the | ||
| # SDK does not retry it automatically. Retry here for concurrent test runs. | ||
| deadline = asyncio.get_running_loop().time() + _THROTTLE_RETRY_TIMEOUT_SECONDS | ||
| while True: | ||
| try: | ||
| response = await client.create_application( | ||
| input=CreateApplicationInput( | ||
| display_name=app_name, | ||
| identity_type="ANONYMOUS", | ||
| tags=_TAGS, | ||
| client_token=str(uuid.uuid4()), | ||
| ) | ||
| ) | ||
| break | ||
| except ThrottlingException: | ||
| if asyncio.get_running_loop().time() >= deadline: | ||
| raise | ||
| await asyncio.sleep(_THROTTLE_RETRY_DELAY_SECONDS) | ||
| application_id = response.application_id | ||
| assert application_id is not None | ||
| await _wait_for_application_active(client, application_id) | ||
|
|
||
| response = await client.create_index( | ||
| input=CreateIndexInput( | ||
| application_id=application_id, | ||
| display_name=index_name, | ||
| tags=_TAGS, | ||
| client_token=str(uuid.uuid4()), | ||
| ) | ||
| ) | ||
| index_id = response.index_id | ||
| assert index_id is not None | ||
| await _wait_for_index_active(client, application_id, index_id) | ||
|
|
||
| response = await client.create_retriever( | ||
| input=CreateRetrieverInput( | ||
| application_id=application_id, | ||
| display_name=retriever_name, | ||
| type=RetrieverType.NATIVE_INDEX, | ||
| configuration=RetrieverConfigurationNativeIndexConfiguration( | ||
| value=NativeIndexConfiguration(index_id=index_id) | ||
| ), | ||
| tags=_TAGS, | ||
| client_token=str(uuid.uuid4()), | ||
| ) | ||
| ) | ||
| retriever_id = response.retriever_id | ||
| assert retriever_id is not None | ||
| await _wait_for_retriever_active(client, application_id, retriever_id) | ||
|
|
||
| return application_id | ||
|
|
||
|
|
||
| async def _delete_qbusiness_app(client, application_id: str | None) -> None: | ||
| """Delete a Q Business application. Cascades to its index and retriever. | ||
|
|
||
| Args: | ||
| client: A Q Business client. | ||
| application_id: The Application ID to delete, or None if creation failed. | ||
| """ | ||
| if not application_id: | ||
| return | ||
| await client.delete_application( | ||
| input=DeleteApplicationInput(application_id=application_id) | ||
| ) | ||
|
|
||
|
|
||
| @pytest.fixture(scope="session") | ||
| async def qbusiness_app(): | ||
| """Create a Q Business application for the test session and delete it after.""" | ||
| unique_suffix = uuid.uuid4().hex[:16] | ||
| app_name = f"integ-test-qbusiness-app-{unique_suffix}" | ||
| index_name = f"integ-test-qbusiness-index-{unique_suffix}" | ||
| retriever_name = f"integ-test-qbusiness-retriever-{unique_suffix}" | ||
|
|
||
| client = create_qbusiness_client(REGION) | ||
| application_id: str | None = None | ||
| try: | ||
| application_id = await _create_qbusiness_app( | ||
| client, app_name, index_name, retriever_name | ||
| ) | ||
| yield application_id | ||
| finally: | ||
| await _delete_qbusiness_app(client, application_id) | ||
109 changes: 109 additions & 0 deletions
109
clients/aws-sdk-qbusiness/tests/integration/test_bidirectional_streaming.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,109 @@ | ||
| # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
|
|
||
| """Test bidirectional event stream handling for the Chat API.""" | ||
|
|
||
| import asyncio | ||
| import uuid | ||
|
|
||
| from smithy_core.aio.eventstream import DuplexEventStream | ||
|
|
||
| from aws_sdk_qbusiness.models import ( | ||
| ChatInput, | ||
| ChatInputStream, | ||
| ChatInputStreamConfigurationEvent, | ||
| ChatInputStreamEndOfInputEvent, | ||
| ChatInputStreamTextEvent, | ||
| ChatOutput, | ||
| ChatOutputStream, | ||
| ChatOutputStreamMetadataEvent, | ||
| ChatOutputStreamTextEvent, | ||
| ChatOutputStreamUnknown, | ||
| ConfigurationEvent, | ||
| EndOfInputEvent, | ||
| TextInputEvent, | ||
| ) | ||
|
|
||
| from . import REGION, create_qbusiness_client | ||
|
|
||
|
|
||
| async def _send_chat_events( | ||
| stream: DuplexEventStream[ChatInputStream, ChatOutputStream, ChatOutput], | ||
| ) -> None: | ||
| """Send chat input events: configuration, text message, end of input.""" | ||
| await stream.input_stream.send( | ||
| ChatInputStreamConfigurationEvent( | ||
| value=ConfigurationEvent(chat_mode="RETRIEVAL_MODE") | ||
| ) | ||
| ) | ||
|
|
||
| await stream.input_stream.send( | ||
| ChatInputStreamTextEvent(value=TextInputEvent(user_message="Hello")) | ||
| ) | ||
|
|
||
| await stream.input_stream.send( | ||
| ChatInputStreamEndOfInputEvent(value=EndOfInputEvent()) | ||
| ) | ||
|
|
||
| await stream.input_stream.close() | ||
|
|
||
|
|
||
| async def _receive_chat_output( | ||
| stream: DuplexEventStream[ChatInputStream, ChatOutputStream, ChatOutput], | ||
| ) -> tuple[bool, bool]: | ||
| """Receive and validate chat output from the stream. | ||
|
|
||
| Returns: | ||
| Tuple of (got_text_events, got_metadata_event) | ||
| """ | ||
| got_text_events = False | ||
| got_metadata_event = False | ||
|
|
||
| _, output_stream = await stream.await_output() | ||
| if output_stream is None: | ||
| return got_text_events, got_metadata_event | ||
|
|
||
| async for event in output_stream: | ||
| if isinstance(event, ChatOutputStreamTextEvent): | ||
| got_text_events = True | ||
| assert event.value.system_message_type == "RESPONSE" | ||
| assert event.value.conversation_id is not None | ||
| assert event.value.user_message_id is not None | ||
| assert event.value.system_message_id is not None | ||
| assert event.value.system_message is not None | ||
| assert isinstance(event.value.system_message, str) | ||
| assert len(event.value.system_message) > 0 | ||
| elif isinstance(event, ChatOutputStreamMetadataEvent): | ||
| got_metadata_event = True | ||
| assert event.value.conversation_id is not None | ||
| assert event.value.user_message_id is not None | ||
| assert event.value.system_message_id is not None | ||
| assert event.value.source_attributions is not None | ||
| assert event.value.final_text_message is not None | ||
| assert isinstance(event.value.final_text_message, str) | ||
| assert len(event.value.final_text_message) > 0 | ||
| elif isinstance(event, ChatOutputStreamUnknown): | ||
| pass | ||
| else: | ||
| raise RuntimeError( | ||
| f"Received unexpected event type in stream: {type(event).__name__}" | ||
| ) | ||
|
|
||
| return got_text_events, got_metadata_event | ||
|
|
||
|
|
||
| async def test_chat_bidirectional_streaming(qbusiness_app: str) -> None: | ||
| """Test bidirectional streaming with text input and chat output.""" | ||
| qbusiness_client = create_qbusiness_client(REGION) | ||
|
|
||
| stream = await qbusiness_client.chat( | ||
| input=ChatInput(application_id=qbusiness_app, client_token=str(uuid.uuid4())) | ||
| ) | ||
|
|
||
| results = await asyncio.gather( | ||
| _send_chat_events(stream), _receive_chat_output(stream) | ||
| ) | ||
| got_text_events, got_metadata_event = results[1] | ||
|
|
||
| assert got_text_events, "Expected to receive text output events" | ||
| assert got_metadata_event, "Expected to receive a metadata event" |
33 changes: 33 additions & 0 deletions
33
clients/aws-sdk-qbusiness/tests/integration/test_non_streaming.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,33 @@ | ||
| # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
|
|
||
| """Test non-streaming output type handling.""" | ||
|
|
||
| import uuid | ||
|
|
||
| from aws_sdk_qbusiness.models import ChatSyncInput, ChatSyncOutput | ||
|
|
||
| from . import REGION, create_qbusiness_client | ||
|
|
||
|
|
||
| async def test_chat_sync(qbusiness_app: str) -> None: | ||
| """Test non-streaming ChatSync operation.""" | ||
| qbusiness_client = create_qbusiness_client(REGION) | ||
|
|
||
| response = await qbusiness_client.chat_sync( | ||
| input=ChatSyncInput( | ||
| application_id=qbusiness_app, | ||
| user_message="Hello", | ||
| client_token=str(uuid.uuid4()), | ||
| ) | ||
| ) | ||
|
|
||
| assert isinstance(response, ChatSyncOutput) | ||
| assert response.conversation_id is not None | ||
| assert response.system_message is not None | ||
| assert isinstance(response.system_message, str) | ||
| assert len(response.system_message) > 0 | ||
| assert response.system_message_id is not None | ||
| assert response.user_message_id is not None | ||
| assert response.source_attributions is not None | ||
| assert response.failed_attachments is not None |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.