diff --git a/.github/workflows/integration_tests.yml b/.github/workflows/integration_tests.yml index d4a18b473..a1f19d0b9 100644 --- a/.github/workflows/integration_tests.yml +++ b/.github/workflows/integration_tests.yml @@ -5,6 +5,7 @@ on: push: branches: - main + - feat/gdb-rw # temporary permissions: id-token: write # This is required for requesting the JWT @@ -17,7 +18,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: [ "3.11", "3.12", "3.13" ] + python-version: [ "3.11", "3.12", "3.13" ] environment: [ "mysql", "pg" ] steps: diff --git a/README.md b/README.md index 7d36c6473..dbe4619bf 100644 --- a/README.md +++ b/README.md @@ -121,6 +121,10 @@ The following table lists the connection properties used with the AWS Advanced P | `secrets_manager_secret_username_key` | [Secrets Manager Plugin](docs/using-the-python-wrapper/using-plugins/UsingTheAwsSecretsManagerPlugin.md) | | `secrets_manager_secret_password_key` | [Secrets Manager Plugin](docs/using-the-python-wrapper/using-plugins/UsingTheAwsSecretsManagerPlugin.md) | | `reader_host_selector_strategy` | [Connection Strategy](docs/using-the-python-wrapper/using-plugins/UsingTheReadWriteSplittingPlugin.md#connection-strategies) | +| `gdb_rw_home_region` | [GDB Read/Write Splitting Plugin](docs/using-the-python-wrapper/using-plugins/UsingTheGdbReadWriteSplittingPlugin.md) | +| `gdb_rw_restrict_writer_to_home_region` | [GDB Read/Write Splitting Plugin](docs/using-the-python-wrapper/using-plugins/UsingTheGdbReadWriteSplittingPlugin.md) | +| `gdb_rw_restrict_reader_to_home_region` | [GDB Read/Write Splitting Plugin](docs/using-the-python-wrapper/using-plugins/UsingTheGdbReadWriteSplittingPlugin.md) | +| `gdb_enable_global_write_forwarding` | [GDB Read/Write Splitting Plugin](docs/using-the-python-wrapper/using-plugins/UsingTheGdbReadWriteSplittingPlugin.md) | | `db_user` | [Federated Authentication Plugin](docs/using-the-python-wrapper/using-plugins/UsingTheFederatedAuthenticationPlugin.md) | | 
`idp_username` | [Federated Authentication Plugin](docs/using-the-python-wrapper/using-plugins/UsingTheFederatedAuthenticationPlugin.md) | | `idp_password` | [Federated Authentication Plugin](docs/using-the-python-wrapper/using-plugins/UsingTheFederatedAuthenticationPlugin.md) | diff --git a/aws_advanced_python_wrapper/failover_plugin.py b/aws_advanced_python_wrapper/failover_plugin.py index cfe14f6e7..062000c26 100644 --- a/aws_advanced_python_wrapper/failover_plugin.py +++ b/aws_advanced_python_wrapper/failover_plugin.py @@ -181,21 +181,24 @@ def notify_host_list_changed(self, changes: Dict[str, Set[HostEvent]]): if not self._enable_failover_setting: return - msg = "" - for key in changes: - msg += f"\n\tHost '{key}': {changes[key]}" - logger.debug("FailoverPlugin.Changes", msg) - - current_host = self._plugin_service.current_host_info - if current_host is not None: - if self._is_host_still_valid(current_host.url, changes): - return - - for alias in current_host.aliases: - if self._is_host_still_valid(alias + '/', changes): + try: + msg = "" + for key in changes: + msg += f"\n\tHost '{key}': {changes[key]}" + logger.debug("FailoverPlugin.Changes", msg) + + current_host = self._plugin_service.current_host_info + if current_host is not None: + if self._is_host_still_valid(current_host.url, changes): return - logger.debug("FailoverPlugin.InvalidHost", current_host) + for alias in current_host.aliases: + if self._is_host_still_valid(alias + '/', changes): + return + + logger.debug("FailoverPlugin.InvalidHost", current_host) + finally: + self._stale_dns_helper.notify_host_list_changed(changes) def connect( self, diff --git a/aws_advanced_python_wrapper/gdb_read_write_splitting_plugin.py b/aws_advanced_python_wrapper/gdb_read_write_splitting_plugin.py new file mode 100644 index 000000000..6cf6f9fde --- /dev/null +++ b/aws_advanced_python_wrapper/gdb_read_write_splitting_plugin.py @@ -0,0 +1,181 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import TYPE_CHECKING, Callable, List, Optional + +from aws_advanced_python_wrapper.errors import ReadWriteSplittingError +from aws_advanced_python_wrapper.plugin import Plugin, PluginFactory +from aws_advanced_python_wrapper.read_write_splitting_plugin import \ + ReadWriteSplittingPlugin +from aws_advanced_python_wrapper.utils.log import Logger +from aws_advanced_python_wrapper.utils.messages import Messages +from aws_advanced_python_wrapper.utils.properties import (Properties, + WrapperProperties) +from aws_advanced_python_wrapper.utils.rds_utils import RdsUtils + +if TYPE_CHECKING: + from aws_advanced_python_wrapper.driver_dialect import DriverDialect + from aws_advanced_python_wrapper.hostinfo import HostInfo + from aws_advanced_python_wrapper.pep249 import Connection + from aws_advanced_python_wrapper.plugin_service import PluginService + +logger = Logger(__name__) + + +class GdbReadWriteSplittingPlugin(ReadWriteSplittingPlugin): + """Read/write splitting plugin for Aurora Global Database. + + Extends the topology-based :class:`ReadWriteSplittingPlugin` to keep + reader and writer connections inside a configured home region. When + enabled, the plugin will refuse to switch to a writer or reader instance + that lives outside that region. 
Optionally, when Global Write + Forwarding is enabled, the plugin will keep the existing reader + connection in a secondary region instead of failing. + """ + + def __init__(self, plugin_service: PluginService, props: Properties): + super().__init__(plugin_service, props) + + self._rds_utils: RdsUtils = RdsUtils() + self._restrict_writer_to_home_region: bool = ( + WrapperProperties.GDB_RW_RESTRICT_WRITER_TO_HOME_REGION.get_bool(props) + ) + self._restrict_reader_to_home_region: bool = ( + WrapperProperties.GDB_RW_RESTRICT_READER_TO_HOME_REGION.get_bool(props) + ) + self._enable_global_write_forwarding: bool = ( + WrapperProperties.GDB_ENABLE_GLOBAL_WRITE_FORWARDING.get_bool(props) + ) + self._home_region: Optional[str] = None + self._initialized: bool = False + + def connect( + self, + target_driver_func: Callable, + driver_dialect: DriverDialect, + host_info: HostInfo, + props: Properties, + is_initial_connection: bool, + connect_func: Callable, + ) -> Connection: + self._init_settings(host_info, props) + return super().connect( + target_driver_func, + driver_dialect, + host_info, + props, + is_initial_connection, + connect_func, + ) + + def _init_settings(self, init_host_info: HostInfo, props: Properties) -> None: + if self._initialized: + return + self._initialized = True + + home_region = WrapperProperties.GDB_RW_HOME_REGION.get(props) + if not home_region: + url_type = self._rds_utils.identify_rds_type(init_host_info.host) + if url_type is not None and url_type.has_region: + home_region = self._rds_utils.get_rds_region(init_host_info.host) + + if not home_region: + raise ReadWriteSplittingError( + Messages.get_formatted( + "GdbReadWriteSplittingPlugin.MissingHomeRegion", + init_host_info.host, + ) + ) + + self._home_region = home_region + logger.debug( + "GdbReadWriteSplittingPlugin.ParameterValue", + WrapperProperties.GDB_RW_HOME_REGION.name, + self._home_region, + ) + + def _initialize_writer_connection(self) -> None: + writer_host = 
self._get_writer_host_info() + if writer_host is not None and self._is_writer_outside_home_region(writer_host): + if self._enable_global_write_forwarding: + logger.debug( + "GdbReadWriteSplittingPlugin.EnabledGwf", + self._rds_utils.get_rds_region(writer_host.host), + ) + return + + raise ReadWriteSplittingError( + Messages.get_formatted( + "GdbReadWriteSplittingPlugin.CantConnectWriterOutOfHomeRegion", + writer_host.host, + self._home_region, + ) + ) + + super()._initialize_writer_connection() + + def _set_writer_connection( + self, writer_conn: Connection, writer_host_info: HostInfo + ) -> None: + if self._is_writer_outside_home_region(writer_host_info): + raise ReadWriteSplittingError( + Messages.get_formatted( + "GdbReadWriteSplittingPlugin.CantConnectWriterOutOfHomeRegion", + writer_host_info.host, + self._home_region, + ) + ) + super()._set_writer_connection(writer_conn, writer_host_info) + + def _get_reader_host_candidates(self) -> List[HostInfo]: + if not self._restrict_reader_to_home_region: + return super()._get_reader_host_candidates() + + hosts_in_region = [ + host + for host in self._plugin_service.hosts + if self._is_in_home_region(host) + ] + + if not hosts_in_region: + raise ReadWriteSplittingError( + Messages.get_formatted( + "GdbReadWriteSplittingPlugin.NoAvailableReadersInHomeRegion", + self._home_region, + ) + ) + + return hosts_in_region + + def _is_writer_outside_home_region(self, host_info: HostInfo) -> bool: + return ( + self._restrict_writer_to_home_region + and not self._is_in_home_region(host_info) + ) + + def _is_in_home_region(self, host_info: HostInfo) -> bool: + if self._home_region is None: + return True + host_region = self._rds_utils.get_rds_region(host_info.host) + if host_region is None: + return False + return host_region.casefold() == self._home_region.casefold() + + +class GdbReadWriteSplittingPluginFactory(PluginFactory): + @staticmethod + def get_instance(plugin_service: PluginService, props: Properties) -> Plugin: + return 
GdbReadWriteSplittingPlugin(plugin_service, props) diff --git a/aws_advanced_python_wrapper/plugin_service.py b/aws_advanced_python_wrapper/plugin_service.py index 82189c3d1..ea650068b 100644 --- a/aws_advanced_python_wrapper/plugin_service.py +++ b/aws_advanced_python_wrapper/plugin_service.py @@ -72,6 +72,8 @@ from aws_advanced_python_wrapper.execute_time_plugin import \ ExecuteTimePluginFactory from aws_advanced_python_wrapper.failover_plugin import FailoverPluginFactory +from aws_advanced_python_wrapper.gdb_read_write_splitting_plugin import \ + GdbReadWriteSplittingPluginFactory from aws_advanced_python_wrapper.host_availability import HostAvailability from aws_advanced_python_wrapper.host_list_provider import ( ConnectionStringHostListProvider, HostListProvider, @@ -830,6 +832,7 @@ class PluginManager(CanReleaseResources): "failover_v2": FailoverV2PluginFactory, "read_write_splitting": ReadWriteSplittingPluginFactory, "srw": SimpleReadWriteSplittingPluginFactory, + "gdb_rw": GdbReadWriteSplittingPluginFactory, "fastest_response_strategy": FastestResponseStrategyPluginFactory, "stale_dns": StaleDnsPluginFactory, "custom_endpoint": CustomEndpointPluginFactory, @@ -855,6 +858,7 @@ class PluginManager(CanReleaseResources): StaleDnsPluginFactory: 200, ReadWriteSplittingPluginFactory: 300, SimpleReadWriteSplittingPluginFactory: 310, + GdbReadWriteSplittingPluginFactory: 320, FailoverPluginFactory: 400, FailoverV2PluginFactory: 410, HostMonitoringPluginFactory: 500, diff --git a/aws_advanced_python_wrapper/read_write_splitting_plugin.py b/aws_advanced_python_wrapper/read_write_splitting_plugin.py index 1ef1e4edf..23bd01742 100644 --- a/aws_advanced_python_wrapper/read_write_splitting_plugin.py +++ b/aws_advanced_python_wrapper/read_write_splitting_plugin.py @@ -16,7 +16,7 @@ from abc import abstractmethod from copy import deepcopy -from typing import TYPE_CHECKING, Any, Callable, Optional, Set, Tuple +from typing import TYPE_CHECKING, Any, Callable, List, Optional, 
Set, Tuple if TYPE_CHECKING: from aws_advanced_python_wrapper.driver_dialect import DriverDialect @@ -344,9 +344,10 @@ def _close_connections(self, close_only_if_idle: bool = True): self._close_connection(self._writer_connection, close_only_if_idle) @staticmethod - def log_and_raise_exception(log_msg: str): - logger.error(log_msg) - raise ReadWriteSplittingError(Messages.get(log_msg)) + def log_and_raise_exception(log_msg: str, *args): + msg = Messages.get_formatted(log_msg, *args) if args else Messages.get(log_msg) + logger.error(log_msg, *args) + raise ReadWriteSplittingError(msg) @staticmethod def _is_connection_usable(conn: Optional[Connection], driver_dialect: DriverDialect): @@ -509,14 +510,14 @@ def _initialize_writer_connection(self): writer_host = self._get_writer_host_info() if writer_host is None: self.log_and_raise_exception( - "ReadWriteSplittingPlugin.FailedToConnectToWriter" + "ReadWriteSplittingPlugin.NoWriterFound" ) return conn = self._plugin_service.connect(writer_host, self._properties, self) if conn is None: self.log_and_raise_exception( - "ReadWriteSplittingPlugin.FailedToConnectToWriter" + "ReadWriteSplittingPlugin.FailedToConnectToWriter", writer_host.url ) return @@ -595,16 +596,25 @@ def _can_host_be_used(self, host_info: HostInfo) -> bool: hosts = [h.get_host_and_port() for h in self._hosts] return host_info.get_host_and_port() in hosts + def _get_reader_host_candidates(self) -> List[HostInfo]: + """Return the list of host candidates used when selecting a reader. + + Subclasses can override this method to filter the candidate list, + for example to restrict readers to a specific region. 
+ """ + return list(self._plugin_service.hosts) + def _open_new_reader_connection( self, ) -> tuple[Optional[Connection], Optional[HostInfo]]: conn: Optional[Connection] = None reader_host: Optional[HostInfo] = None - conn_attempts = len(self._plugin_service.hosts) * 2 + host_candidates = self._get_reader_host_candidates() + conn_attempts = len(host_candidates) * 2 for _ in range(conn_attempts): host = self._plugin_service.get_host_info_by_strategy( - HostRole.READER, self._reader_selector_strategy + HostRole.READER, self._reader_selector_strategy, host_candidates ) if host is not None: try: diff --git a/aws_advanced_python_wrapper/resources/aws_advanced_python_wrapper_messages.properties b/aws_advanced_python_wrapper/resources/aws_advanced_python_wrapper_messages.properties index b5f0da3c2..eb2a8a4cc 100644 --- a/aws_advanced_python_wrapper/resources/aws_advanced_python_wrapper_messages.properties +++ b/aws_advanced_python_wrapper/resources/aws_advanced_python_wrapper_messages.properties @@ -425,6 +425,12 @@ WeightedRandomHostSelector.WeightedRandomInvalidDefaultWeight=[WeightedRandomHos SimpleReadWriteSplittingPlugin.MissingRequiredConfigParameter=[SimpleReadWriteSplittingPlugin] Configuration parameter {} is required. SimpleReadWriteSplittingPlugin.IncorrectConfiguration=[SimpleReadWriteSplittingPlugin] Unable to verify connections with this current configuration. Ensure a correct value is provided to the configuration parameter {}. +GdbReadWriteSplittingPlugin.MissingHomeRegion=[GdbReadWriteSplittingPlugin] Unable to parse home region from endpoint '{}'. Please ensure you have set the 'gdb_rw_home_region' connection parameter. +GdbReadWriteSplittingPlugin.CantConnectWriterOutOfHomeRegion=[GdbReadWriteSplittingPlugin] Writer connection to '{}' is not allowed since it is out of home region '{}'. +GdbReadWriteSplittingPlugin.NoAvailableReadersInHomeRegion=[GdbReadWriteSplittingPlugin] No available reader nodes in home region '{}'. 
+GdbReadWriteSplittingPlugin.ParameterValue=[GdbReadWriteSplittingPlugin] {}={} +GdbReadWriteSplittingPlugin.EnabledGwf=[GdbReadWriteSplittingPlugin] The current primary writer region is '{}' and is not within the home region. Keeping the current connection and letting Global Write Forwarding redirect writes to the primary region. + SqlAlchemyPooledConnectionProvider.PoolNone=[SqlAlchemyPooledConnectionProvider] Attempted to find or create a pool for '{}' but the result of the attempt evaluated to None. SqlAlchemyPooledConnectionProvider.UnableToCreateDefaultKey=[SqlAlchemyPooledConnectionProvider] Unable to create a default key for internal connection pools. By default, the user parameter is used, but the given user evaluated to None or the empty string (""). Please ensure you have passed a valid user in the connection properties. diff --git a/aws_advanced_python_wrapper/simple_read_write_splitting_plugin.py b/aws_advanced_python_wrapper/simple_read_write_splitting_plugin.py index 76bb1b8e5..7bb34913a 100644 --- a/aws_advanced_python_wrapper/simple_read_write_splitting_plugin.py +++ b/aws_advanced_python_wrapper/simple_read_write_splitting_plugin.py @@ -142,7 +142,8 @@ def _initialize_writer_connection(self): if conn is None: self.log_and_raise_exception( - "ReadWriteSplittingPlugin.FailedToConnectToWriter" + "ReadWriteSplittingPlugin.FailedToConnectToWriter", + self._write_endpoint_host_info.url ) return diff --git a/aws_advanced_python_wrapper/utils/properties.py b/aws_advanced_python_wrapper/utils/properties.py index efec458f4..02b9b970e 100644 --- a/aws_advanced_python_wrapper/utils/properties.py +++ b/aws_advanced_python_wrapper/utils/properties.py @@ -615,6 +615,32 @@ class WrapperProperties: 1000, ) + # Global Database Read/Write Splitting + GDB_RW_HOME_REGION = WrapperProperty( + "gdb_rw_home_region", + "Specifies the home region for read/write splitting in a Global Database setup.", + None, + ) + + GDB_RW_RESTRICT_WRITER_TO_HOME_REGION = WrapperProperty( 
+ "gdb_rw_restrict_writer_to_home_region", + "Prevents connections to a writer instance outside of the defined home region.", + True, + ) + + GDB_RW_RESTRICT_READER_TO_HOME_REGION = WrapperProperty( + "gdb_rw_restrict_reader_to_home_region", + "Prevents connections to a reader instance outside of the defined home region.", + True, + ) + + GDB_ENABLE_GLOBAL_WRITE_FORWARDING = WrapperProperty( + "gdb_enable_global_write_forwarding", + "Set to True to enable Global Write Forwarding when connected to a " + "reader connection in a secondary global region.", + False, + ) + class PropertiesUtils: _MONITORING_PROPERTY_PREFIX = "monitoring-" diff --git a/aws_advanced_python_wrapper/utils/rds_utils.py b/aws_advanced_python_wrapper/utils/rds_utils.py index e8cce41ce..1b297894e 100644 --- a/aws_advanced_python_wrapper/utils/rds_utils.py +++ b/aws_advanced_python_wrapper/utils/rds_utils.py @@ -15,7 +15,7 @@ from __future__ import annotations from re import Match, search, sub -from typing import Dict, Optional +from typing import Callable, ClassVar, Dict, Optional from aws_advanced_python_wrapper.utils.rds_url_type import RdsUrlType @@ -108,7 +108,7 @@ class RdsUtils: r"(?P<dns>cluster-|cluster-ro-)+" \ r"(?P<domain>[a-zA-Z0-9]+\.rds\.(?P<region>[a-zA-Z0-9\-]+)" \ r"\.(amazonaws\.com|c2s\.ic\.gov|sc2s\.sgov\.gov))$" - ELB_PATTERN = r"^(?<instance>.+)\.elb\.((?<region>[a-zA-Z0-9\-]+)\.amazonaws\.com)$" + ELB_PATTERN = r"^(?P<instance>.+)\.elb\.((?P<region>[a-zA-Z0-9\-]+)\.amazonaws\.com)$" IP_V4 = r"^(([1-9]|[1-9][0-9]|1[0-9]{2}|2[0-4][0-9]|25[0-5])\.){1}" \ r"(([0-9]|[1-9][0-9]|1[0-9]{2}|2[0-4][0-9]|25[0-5])\.){2}([0-9]|[1-9][0-9]|1[0-9]{2}|2[0-4][0-9]|25[0-5])" @@ -127,77 +127,100 @@ class RdsUtils: CACHE_DNS_PATTERNS: Dict[str, Match[str]] = {} CACHE_PATTERNS: Dict[str, str] = {} + _prepare_host_func: ClassVar[Optional[Callable[[str], Optional[str]]]] = None + + @staticmethod + def set_prepare_host_func(func: Optional[Callable[[str], Optional[str]]]): + RdsUtils._prepare_host_func = func + + @staticmethod + def
reset_prepare_host_func(): + RdsUtils._prepare_host_func = None + + @staticmethod + def _get_prepared_host(host: Optional[str]) -> Optional[str]: + func = RdsUtils._prepare_host_func + if func is None or host is None: + return host + prepared = func(host) + return host if prepared is None else prepared + def is_rds_cluster_dns(self, host: str) -> bool: - dns_group = self._get_dns_group(host) + dns_group = self._get_dns_group(self._get_prepared_host(host)) return dns_group is not None and dns_group.casefold() in ["cluster-", "cluster-ro-"] def is_rds_custom_cluster_dns(self, host: str) -> bool: - dns_group = self._get_dns_group(host) + dns_group = self._get_dns_group(self._get_prepared_host(host)) return dns_group is not None and dns_group.casefold() == "cluster-custom-" - def is_rds_dns(self, host: str) -> bool: - if not host or not host.strip(): + def is_rds_dns(self, host: Optional[str]) -> bool: + prepared_host = self._get_prepared_host(host) + if not prepared_host or not prepared_host.strip(): return False - pattern = self._find(host, [RdsUtils.AURORA_DNS_PATTERN, - RdsUtils.AURORA_CHINA_DNS_PATTERN, - RdsUtils.AURORA_OLD_CHINA_DNS_PATTERN, - RdsUtils.AURORA_GOV_DNS_PATTERN]) + pattern = self._find(prepared_host, [RdsUtils.AURORA_DNS_PATTERN, + RdsUtils.AURORA_CHINA_DNS_PATTERN, + RdsUtils.AURORA_OLD_CHINA_DNS_PATTERN, + RdsUtils.AURORA_GOV_DNS_PATTERN]) group = self._get_regex_group(pattern, RdsUtils.DNS_GROUP) if group: - RdsUtils.CACHE_PATTERNS[host] = group + RdsUtils.CACHE_PATTERNS[prepared_host] = group return pattern is not None - def is_rds_instance(self, host: str) -> bool: - return self._get_dns_group(host) is None and self.is_rds_dns(host) + def is_rds_instance(self, host: Optional[str]) -> bool: + prepared_host = self._get_prepared_host(host) + return self._get_dns_group(prepared_host) is None and self.is_rds_dns(prepared_host) def is_rds_proxy_dns(self, host: str) -> bool: - dns_group = self._get_dns_group(host) + dns_group = 
self._get_dns_group(self._get_prepared_host(host)) return dns_group is not None and dns_group.casefold() == "proxy-" def get_rds_instance_host_pattern(self, host: str) -> str: - if not host or not host.strip(): + prepared_host = self._get_prepared_host(host) + if not prepared_host or not prepared_host.strip(): return "?" - match = self._get_group(host, RdsUtils.DOMAIN_GROUP) + match = self._get_group(prepared_host, RdsUtils.DOMAIN_GROUP) if match: return f"?.{match}" return "?" def get_rds_region(self, host: Optional[str]): - if not host or not host.strip(): + prepared_host = self._get_prepared_host(host) + if not prepared_host or not prepared_host.strip(): return None - group = self._get_group(host, RdsUtils.REGION_GROUP) + group = self._get_group(prepared_host, RdsUtils.REGION_GROUP) if group: return group - elb_matcher = search(RdsUtils.ELB_PATTERN, host) + elb_matcher = search(RdsUtils.ELB_PATTERN, prepared_host) if elb_matcher: return elb_matcher.group(RdsUtils.REGION_GROUP) return None def is_writer_cluster_dns(self, host: str) -> bool: - dns_group = self._get_dns_group(host) + dns_group = self._get_dns_group(self._get_prepared_host(host)) return dns_group is not None and dns_group.casefold() == "cluster-" def is_reader_cluster_dns(self, host: str) -> bool: - dns_group = self._get_dns_group(host) + dns_group = self._get_dns_group(self._get_prepared_host(host)) return dns_group is not None and dns_group.casefold() == "cluster-ro-" def is_global_db_writer_cluster_dns(self, host: str) -> bool: - dns_group = self._get_dns_group(host) + dns_group = self._get_dns_group(self._get_prepared_host(host)) return dns_group is not None and dns_group.casefold() == "global-" def is_limitless_database_shard_group_dns(self, host: str) -> bool: - dns_group = self._get_dns_group(host) + dns_group = self._get_dns_group(self._get_prepared_host(host)) return dns_group is not None and dns_group.casefold() == "shardgrp-" def get_rds_cluster_host_url(self, host: str): - if not host or 
not host.strip(): + prepared_host = self._get_prepared_host(host) + if not prepared_host or not prepared_host.strip(): return None for pattern in [RdsUtils.AURORA_CLUSTER_PATTERN, @@ -205,29 +228,31 @@ def get_rds_cluster_host_url(self, host: str): RdsUtils.AURORA_OLD_CHINA_DNS_PATTERN, RdsUtils.AURORA_GOV_DNS_PATTERN, RdsUtils.AURORA_LIMITLESS_CLUSTER_PATTERN]: - if m := search(pattern, host): + if m := search(pattern, prepared_host): group = self._get_regex_group(m, RdsUtils.DNS_GROUP) if group is not None: if pattern == RdsUtils.AURORA_LIMITLESS_CLUSTER_PATTERN: - return sub(pattern, r"\g<instance>.shardgrp-\g<domain>", host) + return sub(pattern, r"\g<instance>.shardgrp-\g<domain>", prepared_host) else: - return sub(pattern, r"\g<instance>.cluster-\g<domain>", host) + return sub(pattern, r"\g<instance>.cluster-\g<domain>", prepared_host) return None return None def get_cluster_id(self, host: str) -> Optional[str]: - if host is None or not host.strip(): + prepared_host = self._get_prepared_host(host) + if prepared_host is None or not prepared_host.strip(): return None - if self._get_dns_group(host) is not None: - return self._get_group(host, self.INSTANCE_GROUP) + if self._get_dns_group(prepared_host) is not None: + return self._get_group(prepared_host, self.INSTANCE_GROUP) return None def get_instance_id(self, host: str) -> Optional[str]: - if self._get_dns_group(host) is None: - return self._get_group(host, self.INSTANCE_GROUP) + prepared_host = self._get_prepared_host(host) + if self._get_dns_group(prepared_host) is None: + return self._get_group(prepared_host, self.INSTANCE_GROUP) return None @@ -248,53 +273,61 @@ def is_dns_pattern_valid(self, host: str) -> bool: return "?"
in host def identify_rds_type(self, host: Optional[str]) -> RdsUrlType: - if host is None or not host.strip(): + prepared_host = self._get_prepared_host(host) + if prepared_host is None or not prepared_host.strip(): return RdsUrlType.OTHER - if self.is_ip(host): + if self.is_ip(prepared_host): return RdsUrlType.IP_ADDRESS - elif self.is_global_db_writer_cluster_dns(host): + elif self.is_global_db_writer_cluster_dns(prepared_host): return RdsUrlType.RDS_GLOBAL_WRITER_CLUSTER - elif self.is_writer_cluster_dns(host): + elif self.is_writer_cluster_dns(prepared_host): return RdsUrlType.RDS_WRITER_CLUSTER - elif self.is_reader_cluster_dns(host): + elif self.is_reader_cluster_dns(prepared_host): return RdsUrlType.RDS_READER_CLUSTER - elif self.is_limitless_database_shard_group_dns(host): + elif self.is_limitless_database_shard_group_dns(prepared_host): return RdsUrlType.RDS_AURORA_LIMITLESS_DB_SHARD_GROUP - elif self.is_rds_custom_cluster_dns(host): + elif self.is_rds_custom_cluster_dns(prepared_host): return RdsUrlType.RDS_CUSTOM_CLUSTER - elif self.is_rds_proxy_dns(host): + elif self.is_rds_proxy_dns(prepared_host): return RdsUrlType.RDS_PROXY - elif self.is_rds_instance(host): + elif self.is_rds_instance(prepared_host): return RdsUrlType.RDS_INSTANCE return RdsUrlType.OTHER def is_green_instance(self, host: str) -> bool: - if not host: + prepared_host = self._get_prepared_host(host) + if not prepared_host: return False - return search(RdsUtils.BG_GREEN_HOST_PATTERN, host) is not None + return search(RdsUtils.BG_GREEN_HOST_PATTERN, prepared_host) is not None def is_not_old_instance(self, host: str) -> bool: - if host is None or not host.strip(): + prepared_host = self._get_prepared_host(host) + if prepared_host is None or not prepared_host.strip(): return False - return search(RdsUtils.BG_OLD_HOST_PATTERN, host) is None + return search(RdsUtils.BG_OLD_HOST_PATTERN, prepared_host) is None def is_not_green_or_old_instance(self, host: str) -> bool: - if not host: + 
prepared_host = self._get_prepared_host(host) + if not prepared_host: return False - return search(RdsUtils.BG_GREEN_HOST_PATTERN, host) is None and \ - search(RdsUtils.BG_OLD_HOST_PATTERN, host) is None + return search(RdsUtils.BG_GREEN_HOST_PATTERN, prepared_host) is None and \ + search(RdsUtils.BG_OLD_HOST_PATTERN, prepared_host) is None def remove_green_instance_prefix(self, host: str) -> str: if not host: return host - host_match = search(RdsUtils.BG_GREEN_HOST_PATTERN, host) + prepared_host = self._get_prepared_host(host) + if not prepared_host: + return host + + host_match = search(RdsUtils.BG_GREEN_HOST_PATTERN, prepared_host) if host_match is None: - host_id_match = search(RdsUtils.BG_GREEN_HOST_ID_PATTERN, host) + host_id_match = search(RdsUtils.BG_GREEN_HOST_ID_PATTERN, prepared_host) if host_id_match: return host_id_match.group(0) else: @@ -306,7 +339,7 @@ def remove_green_instance_prefix(self, host: str) -> str: return host.replace(f"{prefix}.", ".") - def _find(self, host: str, patterns: list): + def _find(self, host: Optional[str], patterns: list): if not host or not host.strip(): return None @@ -327,7 +360,7 @@ def _get_regex_group(self, pattern: Match[str], group_name: str): return None return pattern.group(group_name) - def _get_group(self, host: str, group: str): + def _get_group(self, host: Optional[str], group: str): if not host or not host.strip(): return None @@ -337,7 +370,7 @@ def _get_group(self, host: str, group: str): RdsUtils.AURORA_GOV_DNS_PATTERN]) return self._get_regex_group(pattern, group) - def _get_dns_group(self, host: str): + def _get_dns_group(self, host: Optional[str]): return self._get_group(host, RdsUtils.DNS_GROUP) def remove_port(self, url: str): diff --git a/docs/using-the-python-wrapper/UsingThePythonWrapper.md b/docs/using-the-python-wrapper/UsingThePythonWrapper.md index 7c6e13f7f..69dd0bcbe 100644 --- a/docs/using-the-python-wrapper/UsingThePythonWrapper.md +++ b/docs/using-the-python-wrapper/UsingThePythonWrapper.md 
@@ -102,6 +102,7 @@ The AWS Advanced Python Wrapper has several built-in plugins that are available | [Aurora Connection Tracker Plugin](./using-plugins/UsingTheAuroraConnectionTrackerPlugin.md) | `aurora_connection_tracker` | Aurora | Tracks all the opened connections. In the event of a cluster failover, the plugin will close all the impacted connections to the host. This plugin is enabled by default. | None | | [Read Write Splitting Plugin](./using-plugins/UsingTheReadWriteSplittingPlugin.md) | `read_write_splitting` | Aurora | Enables read write splitting functionality where users can switch between database reader and writer instances. | None | | [Simple Read Write Splitting Plugin](./using-plugins/UsingTheSimpleReadWriteSplittingPlugin.md) | `srw` | Any database | Enables read write splitting functionality where users can switch between reader and writer endpoints. | None | +| [GDB Read Write Splitting Plugin](./using-plugins/UsingTheGdbReadWriteSplittingPlugin.md) | `gdb_rw` | Global Database | Extends the Read Write Splitting Plugin with home-region awareness for Aurora Global Databases. Allows specifying a home region and constraining new connections to it. | None | | [Fastest Response Strategy Plugin](./using-plugins/UsingTheFastestResponseStrategyPlugin.md) | `fastest_response_strategy` | Aurora | A host selection strategy plugin that uses a host monitoring service to monitor each reader host's response time and choose the host with the fastest response. | None | | [Blue/Green Deployment Plugin](./using-plugins/UsingTheBlueGreenPlugin.md) | `bg` | Aurora,
RDS Instance | Enables client-side Blue/Green Deployment support. | None | | [Limitless Plugin](./using-plugins/UsingTheLimitlessPlugin.md) | `limitless` | Aurora | Enables client-side load-balancing of Transaction Routers on Amazon Aurora Limitless Databases. | None | diff --git a/docs/using-the-python-wrapper/using-plugins/UsingTheGdbReadWriteSplittingPlugin.md b/docs/using-the-python-wrapper/using-plugins/UsingTheGdbReadWriteSplittingPlugin.md new file mode 100644 index 000000000..9b30a44af --- /dev/null +++ b/docs/using-the-python-wrapper/using-plugins/UsingTheGdbReadWriteSplittingPlugin.md @@ -0,0 +1,52 @@ +# Global Database (GDB) Read/Write Splitting Plugin + +The GDB Read/Write Splitting Plugin extends the functionality of the [Read/Write Splitting Plugin](./UsingTheReadWriteSplittingPlugin.md) and adds settings tailored to Aurora Global Databases. + +The plugin introduces the concept of a *home region* and lets you constrain new connections to that region. Such restrictions are helpful in environments where remote AWS regions add latency that cannot be tolerated. + +Unless otherwise stated, all recommendations, configurations, and code examples for the [Read/Write Splitting Plugin](./UsingTheReadWriteSplittingPlugin.md) apply to the GDB Read/Write Splitting Plugin. + +## Loading the GDB Read/Write Splitting Plugin + +The GDB Read/Write Splitting Plugin is not loaded by default. To load it, include `gdb_rw` in the [`plugins`](../UsingThePythonWrapper.md#connection-plugin-manager-parameters) connection parameter. If you want to load it alongside the failover and host monitoring plugins, the GDB Read/Write Splitting Plugin **must be listed before** these plugins in the plugin chain so that failover errors are processed correctly. The wrapper sorts plugins by default, but the relative ordering still matters when sorting is disabled. + +```python +params = { + "plugins": "gdb_rw,failover_v2,host_monitoring_v2", + # Add other connection properties below... 
+} + +# If using MySQL: +conn = AwsWrapperConnection.connect(mysql.connector.connect, **params) + +# If using Postgres: +conn = AwsWrapperConnection.connect(psycopg.Connection.connect, **params) +``` + +If you would like to use the GDB Read/Write Splitting Plugin without the failover plugin, list `gdb_rw` on its own: + +```python +params = { + "plugins": "gdb_rw", + # Add other connection properties below... +} +``` + +> [!WARNING] +> Do not use the `read_write_splitting`, `srw`, and/or `gdb_rw` plugins (or any combination of them) at the same time on the same connection. + +## Using the GDB Read/Write Splitting Plugin against non-GDB clusters + +The GDB Read/Write Splitting Plugin can be used against Aurora and RDS clusters. However, since these cluster types are single-region, configuring a home region adds little value. + +## Configuration Parameters + +| Parameter | Value | Required | Description | Default Value | +|----------------------------------------|:-------:|:------------------------------------------------------------------------------------------------------------------------------------------:|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|----------------------------------------------------------------------------------------------------------------------------------| +| `reader_host_selector_strategy` | String | No | The name of the strategy used to select a new reader host. For more information on the available reader selection strategies, see the [Read/Write Splitting Plugin](./UsingTheReadWriteSplittingPlugin.md#connection-strategies) docs. 
| `random` | +| `gdb_rw_home_region` | String | If connecting using an IP address, a custom domain URL, a Global Database endpoint, or any other endpoint without a region: Yes

Otherwise: No | Defines the home region.

Examples: `us-west-2`, `us-east-1`.

If this parameter is omitted, the home region is parsed from the connection URL. For regional cluster endpoints and instance endpoints, it is set to the region of the provided endpoint. If the endpoint has no region (for example, a Global Database endpoint, an IP address, or a custom domain URL), this parameter must be set explicitly. | For regional cluster endpoints and instance endpoints, it is set to the region of the provided endpoint.

Otherwise: `None` | +| `gdb_rw_restrict_writer_to_home_region`| Boolean | No | If set to `True`, prevents following or connecting to a writer node outside the defined home region. An exception will be raised when such a connection to a writer outside the home region is requested. | `True` | +| `gdb_rw_restrict_reader_to_home_region`| Boolean | No | If set to `True`, prevents connecting to a reader node outside the defined home region. If no reader nodes in the home region are available, an exception will be raised. | `True` | +| `gdb_enable_global_write_forwarding` | Boolean | No | If set to `True`, allows connections in the secondary region to forward write queries to the primary global region. This is useful when your home region is the secondary global region. This functionality requires [Global Write Forwarding](https://docs.aws.amazon.com/AmazonRDS/latest/AuroraUserGuide/aurora-global-database-write-forwarding.html) to be enabled on the cluster. | `False` | + +Refer to the [Read/Write Splitting Plugin](./UsingTheReadWriteSplittingPlugin.md) docs for more details on error handling, connection pooling, and sample code; all of that information applies here. 
diff --git a/tests/integration/container/django/test_django_plugins.py b/tests/integration/container/django/test_django_plugins.py index 826543144..4cf80abae 100644 --- a/tests/integration/container/django/test_django_plugins.py +++ b/tests/integration/container/django/test_django_plugins.py @@ -38,6 +38,7 @@ enable_on_num_instances) from ..utils.database_engine import DatabaseEngine from ..utils.database_engine_deployment import DatabaseEngineDeployment +from ..utils.retry_helper import retry_until from ..utils.test_environment import TestEnvironment from ..utils.test_environment_features import TestEnvironmentFeatures @@ -609,7 +610,8 @@ def test_django_failover_during_query(self, test_environment: TestEnvironment, d with connection.cursor() as cursor: cursor.execute(RdsTestUtility.get_instance_id_query()) current_writer_id = cursor.fetchone()[0] - assert rds_utils.is_db_instance_writer(current_writer_id) is True + # RDS API lags behind the writer election, so we retry the check. + assert retry_until(lambda: rds_utils.is_db_instance_writer(current_writer_id)) assert current_writer_id != initial_writer_id, "Should be connected to a new writer after failover" # Clean up test data @@ -670,7 +672,8 @@ def test_django_custom_endpoint_failover_during_query( with connection.cursor() as cursor: cursor.execute(RdsTestUtility.get_instance_id_query()) current_writer_id = cursor.fetchone()[0] - assert rds_utils.is_db_instance_writer(current_writer_id) is True + # RDS API lags behind the writer election, so we retry the check. 
+ assert retry_until(lambda: rds_utils.is_db_instance_writer(current_writer_id)) assert current_writer_id != initial_writer_id, "Should be connected to a new writer after failover" # Clean up test data diff --git a/tests/integration/container/sqlalchemy/test_sqlalchemy_plugins.py b/tests/integration/container/sqlalchemy/test_sqlalchemy_plugins.py index b5419b026..513ab4116 100644 --- a/tests/integration/container/sqlalchemy/test_sqlalchemy_plugins.py +++ b/tests/integration/container/sqlalchemy/test_sqlalchemy_plugins.py @@ -42,6 +42,7 @@ enable_on_num_instances) from ..utils.database_engine import DatabaseEngine from ..utils.database_engine_deployment import DatabaseEngineDeployment +from ..utils.retry_helper import retry_until from ..utils.test_environment import TestEnvironment from ..utils.test_environment_features import TestEnvironmentFeatures @@ -512,7 +513,8 @@ def test_sqlalchemy_failover_during_query(self, test_environment: TestEnvironmen current_writer_id = row._tuple()[0] else: raise Exception("Failed to get current_writer_id from row because row was None.") - assert rds_utils.is_db_instance_writer(current_writer_id) is True + # RDS API lags behind the writer election, so we retry the check. 
+ assert retry_until(lambda: rds_utils.is_db_instance_writer(current_writer_id)) assert current_writer_id != initial_writer_id session.query(TestModel).delete() diff --git a/tests/integration/container/test_aurora_failover.py b/tests/integration/container/test_aurora_failover.py index 18c25c9fa..23fb38bad 100644 --- a/tests/integration/container/test_aurora_failover.py +++ b/tests/integration/container/test_aurora_failover.py @@ -38,6 +38,7 @@ from aws_advanced_python_wrapper.wrapper import AwsWrapperConnection from .utils.driver_helper import DriverHelper from .utils.rds_test_utility import RdsTestUtility +from .utils.retry_helper import retry_until from .utils.test_environment import TestEnvironment from .utils.test_environment_features import TestEnvironmentFeatures @@ -116,7 +117,8 @@ def test_fail_from_writer_to_new_writer_fail_on_connection_invocation( # assert that we are connected to the new writer after failover happens. current_connection_id = aurora_utility.query_instance_id(aws_conn) - assert aurora_utility.is_db_instance_writer(current_connection_id) is True + # RDS API lags behind the writer election, so we retry the check. + assert retry_until(lambda: aurora_utility.is_db_instance_writer(current_connection_id)) assert current_connection_id != initial_writer_id @pytest.mark.parametrize("plugins", ["failover", "failover_v2"]) @@ -137,7 +139,8 @@ def test_fail_from_writer_to_new_writer_fail_on_connection_bound_object_invocati # assert that we are connected to the new writer after failover happens and we can reuse the cursor current_connection_id = aurora_utility.query_instance_id(aws_conn) - assert aurora_utility.is_db_instance_writer(current_connection_id) is True + # RDS API lags behind the writer election, so we retry the check. 
+ assert retry_until(lambda: aurora_utility.is_db_instance_writer(current_connection_id)) assert current_connection_id != initial_writer_id @pytest.mark.parametrize("plugins", ["failover,host_monitoring", "failover,host_monitoring_v2", @@ -166,7 +169,8 @@ def test_fail_from_reader_to_writer( current_connection_id = aurora_utility.query_instance_id(aws_conn) assert writer_id == current_connection_id - assert aurora_utility.is_db_instance_writer(current_connection_id) is True + # RDS API lags behind the writer election, so we retry the check. + assert retry_until(lambda: aurora_utility.is_db_instance_writer(current_connection_id)) @pytest.mark.parametrize("plugins", ["failover", "failover_v2"]) @enable_on_features([TestEnvironmentFeatures.FAILOVER_SUPPORTED]) @@ -195,7 +199,8 @@ def test_fail_from_writer_with_session_states_autocommit(self, test_driver: Test # Attempt to query the instance id. current_connection_id = aurora_utility.query_instance_id(conn) # Assert that we are connected to the new writer after failover happens. - assert aurora_utility.is_db_instance_writer(current_connection_id) is True + # RDS API lags behind the writer election, so we retry the check. + assert retry_until(lambda: aurora_utility.is_db_instance_writer(current_connection_id)) next_cluster_writer_id = aurora_utility.get_cluster_writer_instance_id() assert current_connection_id == next_cluster_writer_id assert current_connection_id != initial_writer_id @@ -229,7 +234,8 @@ def test_fail_from_writer_with_session_states_readonly(self, test_driver: TestDr # Attempt to query the instance id. current_connection_id = aurora_utility.query_instance_id(conn) # Assert that we are connected to the new writer after failover happens. - assert aurora_utility.is_db_instance_writer(current_connection_id) is True + # RDS API lags behind the writer election, so we retry the check. 
+ assert retry_until(lambda: aurora_utility.is_db_instance_writer(current_connection_id)) next_cluster_writer_id = aurora_utility.get_cluster_writer_instance_id() assert current_connection_id == next_cluster_writer_id assert current_connection_id != initial_writer_id @@ -265,7 +271,8 @@ def test_writer_fail_within_transaction_set_autocommit_false( current_connection_id: str = aurora_utility.query_instance_id(conn) # assert that we are connected to the new writer after failover happens - assert aurora_utility.is_db_instance_writer(current_connection_id) + # RDS API lags behind the writer election, so we retry the check. + assert retry_until(lambda: aurora_utility.is_db_instance_writer(current_connection_id)) next_cluster_writer_id: str = aurora_utility.get_cluster_writer_instance_id() assert current_connection_id == next_cluster_writer_id @@ -311,7 +318,8 @@ def test_writer_fail_within_transaction_start_transaction( current_connection_id: str = aurora_utility.query_instance_id(conn) # assert that we are connected to the new writer after failover happens - assert aurora_utility.is_db_instance_writer(current_connection_id) + # RDS API lags behind the writer election, so we retry the check. 
+ assert retry_until(lambda: aurora_utility.is_db_instance_writer(current_connection_id)) next_cluster_writer_id: str = aurora_utility.get_cluster_writer_instance_id() assert current_connection_id == next_cluster_writer_id diff --git a/tests/integration/container/test_aws_secrets_manager.py b/tests/integration/container/test_aws_secrets_manager.py index 1a26bcc8f..76225b238 100644 --- a/tests/integration/container/test_aws_secrets_manager.py +++ b/tests/integration/container/test_aws_secrets_manager.py @@ -28,6 +28,7 @@ from .utils.database_engine_deployment import DatabaseEngineDeployment from .utils.driver_helper import DriverHelper from .utils.rds_test_utility import RdsTestUtility +from .utils.retry_helper import retry_until from .utils.test_environment import TestEnvironment from .utils.test_environment_features import TestEnvironmentFeatures @@ -231,7 +232,8 @@ def test_failover_with_secrets_manager( aurora_utility.assert_first_query_throws(aws_conn, FailoverSuccessError) current_connection_id = aurora_utility.query_instance_id(aws_conn) - assert aurora_utility.is_db_instance_writer(current_connection_id) is True + # RDS API lags behind the writer election, so we retry the check. 
+ assert retry_until(lambda: aurora_utility.is_db_instance_writer(current_connection_id)) assert current_connection_id != initial_writer_id def validate_connection(self, target_driver_connect: Callable, **connect_params): diff --git a/tests/integration/container/test_custom_endpoint.py b/tests/integration/container/test_custom_endpoint.py index 625e4ae86..ef7877da5 100644 --- a/tests/integration/container/test_custom_endpoint.py +++ b/tests/integration/container/test_custom_endpoint.py @@ -39,6 +39,7 @@ DatabaseEngineDeployment from tests.integration.container.utils.driver_helper import DriverHelper from tests.integration.container.utils.rds_test_utility import RdsTestUtility +from tests.integration.container.utils.retry_helper import retry_until from tests.integration.container.utils.test_environment import TestEnvironment from tests.integration.container.utils.test_environment_features import \ TestEnvironmentFeatures @@ -327,6 +328,10 @@ def test_custom_endpoint_read_write_splitting__with_custom_endpoint_changes__wit with pytest.raises(ReadWriteSplittingError): conn.read_only = False + # The RDS API lags behind the writer election triggered during setup, so it may still report + # the previous writer (now the reader we are connected to). Retry until the API reflects a + # writer distinct from our reader, otherwise StaticMembers would contain duplicate ids. + assert retry_until(lambda: rds_utils.get_cluster_writer_instance_id() != original_reader_id) writer_id = rds_utils.get_cluster_writer_instance_id() rds_client = client('rds', region_name=TestEnvironment.get_current().get_aurora_region()) @@ -398,6 +403,10 @@ def test_custom_endpoint_read_write_splitting__with_custom_endpoint_changes__wit assert new_instance_id == original_writer_id instances = TestEnvironment.get_current().get_instances() + # The RDS API lags behind the writer election triggered during setup, so it may still report + # a stale writer. 
Retry until the API agrees the instance we are connected to is the writer, + # otherwise the reader selection below could pick our own writer and create duplicate ids. + assert retry_until(lambda: rds_utils.get_cluster_writer_instance_id() == original_writer_id) writer_id = str(rds_utils.get_cluster_writer_instance_id()) reader_id_to_add = "" diff --git a/tests/integration/container/test_iam_authentication.py b/tests/integration/container/test_iam_authentication.py index c83f70a2e..6b92f62e8 100644 --- a/tests/integration/container/test_iam_authentication.py +++ b/tests/integration/container/test_iam_authentication.py @@ -40,6 +40,7 @@ DatabaseEngineDeployment from tests.integration.container.utils.driver_helper import DriverHelper from tests.integration.container.utils.rds_test_utility import RdsTestUtility +from tests.integration.container.utils.retry_helper import retry_until from tests.integration.container.utils.test_environment import TestEnvironment @@ -164,7 +165,8 @@ def test_failover_with_iam( # assert that we are connected to the new writer after failover happens and we can reuse the cursor current_connection_id = aurora_utility.query_instance_id(aws_conn) - assert aurora_utility.is_db_instance_writer(current_connection_id) is True + # RDS API lags behind the writer election, so we retry the check. 
+ assert retry_until(lambda: aurora_utility.is_db_instance_writer(current_connection_id)) assert current_connection_id != initial_writer_id def get_ip_address(self, hostname: str): diff --git a/tests/integration/container/test_read_write_splitting.py b/tests/integration/container/test_read_write_splitting.py index 931b17c8a..82e25f2fd 100644 --- a/tests/integration/container/test_read_write_splitting.py +++ b/tests/integration/container/test_read_write_splitting.py @@ -23,12 +23,14 @@ from aws_advanced_python_wrapper.errors import ( AwsWrapperError, FailoverFailedError, FailoverSuccessError, ReadWriteSplittingError, TransactionResolutionUnknownError) +from aws_advanced_python_wrapper.plugin_service import PluginServiceImpl from aws_advanced_python_wrapper.sql_alchemy_connection_provider import \ SqlAlchemyPooledConnectionProvider from aws_advanced_python_wrapper.utils import services_container from aws_advanced_python_wrapper.utils.log import Logger from aws_advanced_python_wrapper.utils.properties import (Properties, WrapperProperties) +from aws_advanced_python_wrapper.utils.rds_utils import RdsUtils from tests.integration.container.utils.conditions import ( disable_on_engines, disable_on_features, enable_on_deployments, enable_on_features, enable_on_num_instances) @@ -38,6 +40,7 @@ from tests.integration.container.utils.driver_helper import DriverHelper from tests.integration.container.utils.proxy_helper import ProxyHelper from tests.integration.container.utils.rds_test_utility import RdsTestUtility +from tests.integration.container.utils.retry_helper import retry_until from tests.integration.container.utils.test_driver import TestDriver from tests.integration.container.utils.test_environment import TestEnvironment from tests.integration.container.utils.test_environment_features import \ @@ -64,16 +67,37 @@ def setup_method(self, request): release_resources() gc.collect() + @pytest.fixture(autouse=True) + def configure_prepare_host_func(self): + proxied_suffix = 
".proxied" + + def strip_proxied(host: str) -> str: + if host and host.endswith(proxied_suffix): + return host[: -len(proxied_suffix)] + return host + + RdsUtils.set_prepare_host_func(strip_proxied) + RdsUtils.clear_cache() + try: + yield + finally: + RdsUtils.reset_prepare_host_func() + RdsUtils.clear_cache() + # Plugin configurations @pytest.fixture( - params=[("read_write_splitting", "read_write_splitting"), ("srw", "srw")] + params=[ + ("read_write_splitting", "read_write_splitting"), + ("srw", "srw"), + ("gdb_rw", "gdb_rw"), + ] ) def plugin_config(self, request): return request.param @pytest.fixture(scope="class") def rds_utils(self): - region: str = TestEnvironment.get_current().get_info().get_region() + region: str = TestEnvironment.get_current().get_aurora_region() return RdsTestUtility(region) @pytest.fixture(autouse=True) @@ -94,6 +118,7 @@ def props(self, plugin_config, conn_utils): "socket_timeout": 10, "connect_timeout": 10, "autocommit": True, + "gdb_rw_home_region": TestEnvironment.get_current().get_aurora_region() } ) @@ -135,7 +160,8 @@ def failover_props(self, plugin_config, conn_utils): "socket_timeout": 10, "connect_timeout": 10, "autocommit": True, - "cluster_id": "cluster1" + "cluster_id": "cluster1", + "gdb_rw_home_region": TestEnvironment.get_current().get_aurora_region() } # Add simple plugin specific configuration if plugin_name == "srw": @@ -239,9 +265,9 @@ def test_connect_to_reader__switch_read_only( plugin_config, ): plugin_name, _ = plugin_config - if plugin_name != "read_write_splitting": + if plugin_name not in ("read_write_splitting", "gdb_rw"): pytest.skip( - "Test only applies to read_write_splitting plugin: srw does not connect to instances" + "Test only applies to topology-based read/write splitting plugins: srw does not connect to instances" ) target_driver_connect = DriverHelper.get_connect_func(test_driver) reader_instance = test_environment.get_instances()[1] @@ -511,7 +537,10 @@ def 
test_failover_to_new_writer__switch_read_only( new_writer_id = rds_utils.query_instance_id(conn) assert original_writer_id != new_writer_id - assert rds_utils.is_db_instance_writer(new_writer_id) + # RDS API lags behind the writer election, so we retry the check. + assert retry_until(lambda: rds_utils.is_db_instance_writer(new_writer_id)) + + PluginServiceImpl._host_availability_expiring_cache.clear() conn.read_only = True current_id = rds_utils.query_instance_id(conn) @@ -521,7 +550,15 @@ def test_failover_to_new_writer__switch_read_only( current_id = rds_utils.query_instance_id(conn) assert new_writer_id == current_id - @pytest.mark.parametrize("plugins", ["read_write_splitting,failover,host_monitoring", "read_write_splitting,failover,host_monitoring_v2"]) + @pytest.mark.parametrize( + "plugins", + [ + "read_write_splitting,failover,host_monitoring", + "read_write_splitting,failover,host_monitoring_v2", + "gdb_rw,failover,host_monitoring", + "gdb_rw,failover,host_monitoring_v2", + ], + ) @enable_on_features([TestEnvironmentFeatures.NETWORK_OUTAGES_ENABLED, TestEnvironmentFeatures.ABORT_CONNECTION_SUPPORTED]) @enable_on_num_instances(min_instances=3) @@ -537,10 +574,10 @@ def test_failover_to_new_reader__switch_read_only( plugins ): plugin_name, _ = plugin_config - if plugin_name != "read_write_splitting": + if plugin_name not in ("read_write_splitting", "gdb_rw"): # Disabling the reader connection in srw, the srwReadEndpoint, results in defaulting to the writer not connecting to another reader. 
pytest.skip( - "Test only applies to read_write_splitting plugin: reader connection failover" + "Test only applies to topology-based read/write splitting plugins: reader connection failover" ) WrapperProperties.FAILOVER_MODE.set(proxied_failover_props, "reader-or-writer") @@ -1013,9 +1050,9 @@ def test_pooled_connection__least_connections( plugin_config, ): plugin_name, _ = plugin_config - if plugin_name != "read_write_splitting": + if plugin_name not in ("read_write_splitting", "gdb_rw"): pytest.skip( - "Test only applies to read_write_splitting plugin: reader host selector strategy" + "Test only applies to topology-based read/write splitting plugins: reader host selector strategy" ) WrapperProperties.READER_HOST_SELECTOR_STRATEGY.set(props, "least_connections") @@ -1062,9 +1099,9 @@ def test_pooled_connection__least_connections__pool_mapping( plugin_config, ): plugin_name, _ = plugin_config - if plugin_name != "read_write_splitting": + if plugin_name not in ("read_write_splitting", "gdb_rw"): pytest.skip( - "Test only applies to read_write_splitting plugin: reader host selector strategy" + "Test only applies to topology-based read/write splitting plugins: reader host selector strategy" ) WrapperProperties.READER_HOST_SELECTOR_STRATEGY.set(props, "least_connections") diff --git a/tests/integration/container/utils/retry_helper.py b/tests/integration/container/utils/retry_helper.py new file mode 100644 index 000000000..2d7f2e70e --- /dev/null +++ b/tests/integration/container/utils/retry_helper.py @@ -0,0 +1,54 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import timeit +from time import sleep +from typing import TYPE_CHECKING, Callable + +from aws_advanced_python_wrapper.utils.log import Logger + +if TYPE_CHECKING: + from .rds_test_utility import RdsTestUtility + +logger = Logger(__name__) + +_DEFAULT_TIMEOUT_SEC: float = 5 * 60 # 5 minutes +_DEFAULT_DELAY_SEC: float = 5 # 5 seconds + + +def retry_until( + condition: Callable[[], bool], + timeout_sec: float = _DEFAULT_TIMEOUT_SEC, + delay_sec: float = _DEFAULT_DELAY_SEC) -> bool: + deadline = timeit.default_timer() + timeout_sec + + while timeit.default_timer() < deadline: + if condition(): + return True + sleep(delay_sec) + + return False + + +def verify_writer(aurora_util: RdsTestUtility, expected_writer_id: str) -> bool: + logger.debug("Expected writer (API): " + expected_writer_id) + + def _condition() -> bool: + api_writer_id = aurora_util.get_cluster_writer_instance_id() + logger.debug("Writer (API): " + api_writer_id) + return expected_writer_id.lower() == api_writer_id.lower() + + return retry_until(_condition) diff --git a/tests/unit/test_rds_utils.py b/tests/unit/test_rds_utils.py index 3858add7d..a75bcb219 100644 --- a/tests/unit/test_rds_utils.py +++ b/tests/unit/test_rds_utils.py @@ -185,6 +185,7 @@ def test_get_rds_instance_host_pattern(expected, test_value): ("us-isob-east-1", us_isob_east_region_custom_domain), ("us-isob-east-1", us_isob_east_region_limitless_db_shard_group), ("us-gov-east-1", us_gov_east_region_cluster), + ("us-east-2", us_east_region_elb_url), ]) def test_get_rds_region(expected, 
test_value): target = RdsUtils()