diff --git a/CHANGELOG.md b/CHANGELOG.md index f48ab6645..938a839e9 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -3,6 +3,9 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/#semantic-versioning-200). +### :magic_wand: Added +- Custom Endpoint Plugin. See [UsingTheCustomEndpointPlugin.md](https://github.com/aws/aws-advanced-jdbc-wrapper/blob/main/docs/using-the-jdbc-driver/using-plugins/UsingTheCustomEndpointPlugin.md). + ### :bug: Fixed - Use the cluster URL as the default cluster ID ([PR #1131](https://github.com/aws/aws-advanced-jdbc-wrapper/pull/1131)). - Fix logic in SlidingExpirationCache and SlidingExpirationCacheWithCleanupThread ([PR #1142](https://github.com/aws/aws-advanced-jdbc-wrapper/pull/1142)). diff --git a/wrapper/src/main/java/software/amazon/jdbc/AllowedAndBlockedHosts.java b/wrapper/src/main/java/software/amazon/jdbc/AllowedAndBlockedHosts.java new file mode 100644 index 000000000..6941de9b8 --- /dev/null +++ b/wrapper/src/main/java/software/amazon/jdbc/AllowedAndBlockedHosts.java @@ -0,0 +1,66 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package software.amazon.jdbc; + +import java.util.Collections; +import java.util.Set; +import org.checkerframework.checker.nullness.qual.Nullable; +import software.amazon.jdbc.util.Utils; + +/** + * Represents the allowed and blocked hosts for connections. + */ +public class AllowedAndBlockedHosts { + @Nullable private final Set allowedHostIds; + @Nullable private final Set blockedHostIds; + + /** + * Constructs an AllowedAndBlockedHosts instance with the specified allowed and blocked host IDs. + * + * @param allowedHostIds The set of allowed host IDs for connections. If null or empty, all host IDs that are not in + * {@code blockedHostIds} are allowed. + * @param blockedHostIds The set of blocked host IDs for connections. If null or empty, all host IDs in + * {@code allowedHostIds} are allowed. If {@code allowedHostIds} is also null or empty, there + * are no restrictions on which hosts are allowed. + */ + public AllowedAndBlockedHosts(@Nullable Set allowedHostIds, @Nullable Set blockedHostIds) { + this.allowedHostIds = Utils.isNullOrEmpty(allowedHostIds) ? null : Collections.unmodifiableSet(allowedHostIds); + this.blockedHostIds = Utils.isNullOrEmpty(blockedHostIds) ? null : Collections.unmodifiableSet(blockedHostIds); + } + + /** + * Returns the set of allowed host IDs for connections. If null or empty, all host IDs that are not in + * {@code blockedHostIds} are allowed. + * + * @return the set of allowed host IDs for connections. + */ + @Nullable + public Set getAllowedHostIds() { + return this.allowedHostIds; + } + + /** + * Returns the set of blocked host IDs for connections. If null or empty, all host IDs in {@code allowedHostIds} are + * allowed. If {@code allowedHostIds} is also null or empty, there are no restrictions on which hosts are allowed. + * + * @return the set of blocked host IDs for connections. + */ + @Nullable + public Set getBlockedHostIds() { + return this.blockedHostIds; + } +} diff --git a/wrapper/src/main/java/software/amazon/jdbc/ConnectionPluginChainBuilder.java b/wrapper/src/main/java/software/amazon/jdbc/ConnectionPluginChainBuilder.java index 0b701a9c7..6b41835f8 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/ConnectionPluginChainBuilder.java +++ b/wrapper/src/main/java/software/amazon/jdbc/ConnectionPluginChainBuilder.java @@ -35,6 +35,7 @@ import software.amazon.jdbc.plugin.DriverMetaDataConnectionPluginFactory; import software.amazon.jdbc.plugin.ExecutionTimeConnectionPluginFactory; import software.amazon.jdbc.plugin.LogQueryConnectionPluginFactory; +import software.amazon.jdbc.plugin.customendpoint.CustomEndpointPluginFactory; import software.amazon.jdbc.plugin.dev.DeveloperConnectionPluginFactory; import software.amazon.jdbc.plugin.efm.HostMonitoringConnectionPluginFactory; import software.amazon.jdbc.plugin.failover.FailoverConnectionPluginFactory; @@ -63,6 +64,7 @@ public class ConnectionPluginChainBuilder { put("executionTime", ExecutionTimeConnectionPluginFactory.class); put("logQuery", LogQueryConnectionPluginFactory.class); put("dataCache", DataCacheConnectionPluginFactory.class); + put("customEndpoint", CustomEndpointPluginFactory.class); put("efm", HostMonitoringConnectionPluginFactory.class); put("efm2", software.amazon.jdbc.plugin.efm2.HostMonitoringConnectionPluginFactory.class); put("failover", FailoverConnectionPluginFactory.class); @@ -93,6 +95,7 @@ public class ConnectionPluginChainBuilder { { put(DriverMetaDataConnectionPluginFactory.class, 100); put(DataCacheConnectionPluginFactory.class, 200); + put(CustomEndpointPluginFactory.class, 380); put(AuroraInitialConnectionStrategyPluginFactory.class, 390); put(AuroraConnectionTrackerPluginFactory.class, 400); put(AuroraStaleDnsPluginFactory.class, 500); diff --git a/wrapper/src/main/java/software/amazon/jdbc/ConnectionPluginManager.java b/wrapper/src/main/java/software/amazon/jdbc/ConnectionPluginManager.java index 467f1a5df..f0a42e85e 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/ConnectionPluginManager.java +++ b/wrapper/src/main/java/software/amazon/jdbc/ConnectionPluginManager.java @@ -37,6 +37,7 @@ import software.amazon.jdbc.plugin.DefaultConnectionPlugin; import software.amazon.jdbc.plugin.ExecutionTimeConnectionPlugin; import software.amazon.jdbc.plugin.LogQueryConnectionPlugin; +import software.amazon.jdbc.plugin.customendpoint.CustomEndpointPlugin; import software.amazon.jdbc.plugin.efm.HostMonitoringConnectionPlugin; import software.amazon.jdbc.plugin.failover.FailoverConnectionPlugin; import software.amazon.jdbc.plugin.federatedauth.FederatedAuthPlugin; @@ -50,6 +51,7 @@ import software.amazon.jdbc.util.AsynchronousMethodsHelper; import software.amazon.jdbc.util.Messages; import software.amazon.jdbc.util.SqlMethodAnalyzer; +import software.amazon.jdbc.util.Utils; import software.amazon.jdbc.util.WrapperUtils; import software.amazon.jdbc.util.telemetry.TelemetryContext; import software.amazon.jdbc.util.telemetry.TelemetryFactory; @@ -85,6 +87,7 @@ public class ConnectionPluginManager implements CanReleaseResources, Wrapper { put(FastestResponseStrategyPlugin.class, "plugin:fastestResponseStrategy"); put(DefaultConnectionPlugin.class, "plugin:targetDriver"); put(AuroraInitialConnectionStrategyPlugin.class, "plugin:initialConnection"); + put(CustomEndpointPlugin.class, "plugin:customEndpoint"); } }; @@ -493,7 +496,7 @@ public HostSpec getHostSpecByStrategy(List hosts, HostRole role, Strin if (isSubscribed) { try { - final HostSpec host = hosts == null || hosts.isEmpty() + final HostSpec host = Utils.isNullOrEmpty(hosts) ? plugin.getHostSpecByStrategy(role, strategy) : plugin.getHostSpecByStrategy(hosts, role, strategy); diff --git a/wrapper/src/main/java/software/amazon/jdbc/PluginService.java b/wrapper/src/main/java/software/amazon/jdbc/PluginService.java index 05164ca12..c6e08a980 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/PluginService.java +++ b/wrapper/src/main/java/software/amazon/jdbc/PluginService.java @@ -62,10 +62,31 @@ EnumSet setCurrentConnection( @Nullable ConnectionPlugin skipNotificationForThisPlugin) throws SQLException; + /** + * Get host information for all hosts in the cluster. + * + * @return host information for all hosts in the cluster. + */ + List getAllHosts(); + + /** + * Get host information for allowed hosts in the cluster. Certain hosts in the cluster may be disallowed, and these + * hosts will not be returned by this function. For example, if a custom endpoint is being used, hosts outside the + * custom endpoint will not be returned. + * + * @return host information for allowed hosts in the cluster. + */ List getHosts(); HostSpec getInitialConnectionHostSpec(); + /** + * Set the collection of hosts that should be allowed and/or blocked for connections. + * + * @param allowedAndBlockedHosts An object defining the allowed and blocked sets of hosts. + */ + void setAllowedAndBlockedHosts(AllowedAndBlockedHosts allowedAndBlockedHosts); + /** * Returns a boolean indicating if the available {@link ConnectionProvider} or * {@link ConnectionPlugin} instances support the selection of a host with the requested role and diff --git a/wrapper/src/main/java/software/amazon/jdbc/PluginServiceImpl.java b/wrapper/src/main/java/software/amazon/jdbc/PluginServiceImpl.java index 10b4d0486..887135f4e 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/PluginServiceImpl.java +++ b/wrapper/src/main/java/software/amazon/jdbc/PluginServiceImpl.java @@ -31,6 +31,7 @@ import java.util.Set; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; +import java.util.concurrent.atomic.AtomicReference; import java.util.concurrent.locks.ReentrantLock; import java.util.logging.Logger; import java.util.stream.Collectors; @@ -52,6 +53,7 @@ import software.amazon.jdbc.targetdriverdialect.TargetDriverDialect; import software.amazon.jdbc.util.CacheMap; import software.amazon.jdbc.util.Messages; +import software.amazon.jdbc.util.Utils; import software.amazon.jdbc.util.telemetry.TelemetryFactory; public class PluginServiceImpl implements PluginService, CanReleaseResources, @@ -66,7 +68,8 @@ public class PluginServiceImpl implements PluginService, CanReleaseResources, private final String originalUrl; private final String driverProtocol; protected volatile HostListProvider hostListProvider; - protected List hosts = new ArrayList<>(); + protected List allHosts = new ArrayList<>(); + protected AtomicReference allowedAndBlockedHosts = new AtomicReference<>(); protected Connection currentConnection; protected HostSpec currentHostSpec; protected HostSpec initialConnectionHostSpec; @@ -162,10 +165,20 @@ public HostSpec getCurrentHostSpec() { this.currentHostSpec = this.initialConnectionHostSpec; if (this.currentHostSpec == null) { - if (this.getHosts().isEmpty()) { + if (this.getAllHosts().isEmpty()) { throw new RuntimeException(Messages.get("PluginServiceImpl.hostListEmpty")); } - this.currentHostSpec = this.getWriter(this.getHosts()); + + this.currentHostSpec = this.getWriter(this.getAllHosts()); + if (!this.getHosts().contains(this.currentHostSpec)) { + throw new RuntimeException( + Messages.get("PluginServiceImpl.currentHostNotAllowed", + new Object[] { + currentHostSpec == null ? "" : currentHostSpec.getHost(), + Utils.logTopology(this.getHosts(), "")}) + ); + } + if (this.currentHostSpec == null) { this.currentHostSpec = this.getHosts().get(0); } @@ -187,6 +200,11 @@ public HostSpec getInitialConnectionHostSpec() { return this.initialConnectionHostSpec; } + @Override + public void setAllowedAndBlockedHosts(AllowedAndBlockedHosts allowedAndBlockedHosts) { + this.allowedAndBlockedHosts.set(allowedAndBlockedHosts); + } + @Override public boolean acceptsStrategy(HostRole role, String strategy) throws SQLException { return this.pluginManager.acceptsStrategy(role, strategy); @@ -364,9 +382,35 @@ protected EnumSet compare( return changes; } + @Override + public List getAllHosts() { + return this.allHosts; + } + @Override public List getHosts() { - return this.hosts; + AllowedAndBlockedHosts hostPermissions = this.allowedAndBlockedHosts.get(); + if (hostPermissions == null) { + return this.allHosts; + } + + List hosts = this.allHosts; + Set allowedHostIds = hostPermissions.getAllowedHostIds(); + Set blockedHostIds = hostPermissions.getBlockedHostIds(); + + if (!Utils.isNullOrEmpty(allowedHostIds)) { + hosts = hosts.stream() + .filter((hostSpec -> allowedHostIds.contains(hostSpec.getHostId()))) + .collect(Collectors.toList()); + } + + if (!Utils.isNullOrEmpty(blockedHostIds)) { + hosts = hosts.stream() + .filter((hostSpec -> !blockedHostIds.contains(hostSpec.getHostId()))) + .collect(Collectors.toList()); + } + + return hosts; } @Override @@ -376,7 +420,7 @@ public void setAvailability(final @NonNull Set hostAliases, final @NonNu return; } - final List hostsToChange = this.getHosts().stream() + final List hostsToChange = this.getAllHosts().stream() .filter((host) -> hostAliases.contains(host.asAlias()) || host.getAliases().stream().anyMatch(hostAliases::contains)) .distinct() @@ -427,18 +471,18 @@ public HostListProvider getHostListProvider() { @Override public void refreshHostList() throws SQLException { final List updatedHostList = this.getHostListProvider().refresh(); - if (!Objects.equals(updatedHostList, this.hosts)) { + if (!Objects.equals(updatedHostList, this.allHosts)) { updateHostAvailability(updatedHostList); - setNodeList(this.hosts, updatedHostList); + setNodeList(this.allHosts, updatedHostList); } } @Override public void refreshHostList(final Connection connection) throws SQLException { final List updatedHostList = this.getHostListProvider().refresh(connection); - if (!Objects.equals(updatedHostList, this.hosts)) { + if (!Objects.equals(updatedHostList, this.allHosts)) { updateHostAvailability(updatedHostList); - setNodeList(this.hosts, updatedHostList); + setNodeList(this.allHosts, updatedHostList); } } @@ -447,7 +491,7 @@ public void forceRefreshHostList() throws SQLException { final List updatedHostList = this.getHostListProvider().forceRefresh(); if (updatedHostList != null) { updateHostAvailability(updatedHostList); - setNodeList(this.hosts, updatedHostList); + setNodeList(this.allHosts, updatedHostList); } } @@ -456,7 +500,7 @@ public void forceRefreshHostList(final Connection connection) throws SQLExceptio final List updatedHostList = this.getHostListProvider().forceRefresh(connection); if (updatedHostList != null) { updateHostAvailability(updatedHostList); - setNodeList(this.hosts, updatedHostList); + setNodeList(this.allHosts, updatedHostList); } } @@ -476,7 +520,7 @@ public boolean forceRefreshHostList(final boolean shouldVerifyWriter, final long ((BlockingHostListProvider) hostListProvider).forceRefresh(shouldVerifyWriter, timeoutMs); if (updatedHostList != null) { updateHostAvailability(updatedHostList); - setNodeList(this.hosts, updatedHostList); + setNodeList(this.allHosts, updatedHostList); return true; } } catch (TimeoutException ex) { @@ -520,7 +564,7 @@ void setNodeList(@Nullable final List oldHosts, } if (!changes.isEmpty()) { - this.hosts = newHosts != null ? newHosts : new ArrayList<>(); + this.allHosts = newHosts != null ? newHosts : new ArrayList<>(); this.pluginManager.notifyNodeListChanged(changes); } } diff --git a/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/RdsHostListProvider.java b/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/RdsHostListProvider.java index c2d4dfc9a..98482b329 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/RdsHostListProvider.java +++ b/wrapper/src/main/java/software/amazon/jdbc/hostlistprovider/RdsHostListProvider.java @@ -525,7 +525,7 @@ public List refresh(final Connection connection) throws SQLException { : this.hostListProviderService.getCurrentConnection(); final FetchTopologyResult results = getTopology(currentConnection, false); - LOGGER.finest(() -> Utils.logTopology(results.hosts, results.isCachedData ? "[From cache] " : "")); + LOGGER.finest(() -> Utils.logTopology(results.hosts, results.isCachedData ? "[From cache] Topology:" : null)); this.hostList = results.hosts; return Collections.unmodifiableList(hostList); diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/AuroraConnectionTrackerPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/AuroraConnectionTrackerPlugin.java index 7ea150f4c..3d7f0cdd4 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/AuroraConnectionTrackerPlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/AuroraConnectionTrackerPlugin.java @@ -136,7 +136,7 @@ public T execute(final Class resultClass, final Clas } private void checkWriterChanged() { - final HostSpec hostSpecAfterFailover = this.getWriter(this.pluginService.getHosts()); + final HostSpec hostSpecAfterFailover = this.getWriter(this.pluginService.getAllHosts()); if (this.currentWriter == null) { this.currentWriter = hostSpecAfterFailover; @@ -153,7 +153,7 @@ private void checkWriterChanged() { private void rememberWriter() { if (this.currentWriter == null || this.needUpdateCurrentWriter) { - this.currentWriter = this.getWriter(this.pluginService.getHosts()); + this.currentWriter = this.getWriter(this.pluginService.getAllHosts()); this.needUpdateCurrentWriter = false; } } diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/AuroraInitialConnectionStrategyPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/AuroraInitialConnectionStrategyPlugin.java index 15ad1cf0d..8dbbfe67c 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/AuroraInitialConnectionStrategyPlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/AuroraInitialConnectionStrategyPlugin.java @@ -351,7 +351,7 @@ private void delay(final long delayMs) { } private HostSpec getWriter() { - for (final HostSpec host : this.pluginService.getHosts()) { + for (final HostSpec host : this.pluginService.getAllHosts()) { if (host.getRole() == HostRole.WRITER) { return host; } @@ -380,12 +380,12 @@ private HostSpec getReader(final Properties props) throws SQLException { } private boolean hasNoReaders() { - if (this.pluginService.getHosts().isEmpty()) { + if (this.pluginService.getAllHosts().isEmpty()) { // Topology inconclusive/corrupted. return false; } - for (HostSpec hostSpec : this.pluginService.getHosts()) { + for (HostSpec hostSpec : this.pluginService.getAllHosts()) { if (hostSpec.getRole() == HostRole.WRITER) { continue; } diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/AwsSecretsManagerConnectionPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/AwsSecretsManagerConnectionPlugin.java index 1cd437e08..4032588b7 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/AwsSecretsManagerConnectionPlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/AwsSecretsManagerConnectionPlugin.java @@ -50,6 +50,7 @@ import software.amazon.jdbc.PropertyDefinition; import software.amazon.jdbc.authentication.AwsCredentialsManager; import software.amazon.jdbc.util.Messages; +import software.amazon.jdbc.util.RegionUtils; import software.amazon.jdbc.util.StringUtils; import software.amazon.jdbc.util.telemetry.TelemetryContext; import software.amazon.jdbc.util.telemetry.TelemetryCounter; @@ -79,6 +80,7 @@ public class AwsSecretsManagerConnectionPlugin extends AbstractConnectionPlugin "secretsManagerEndpoint", null, "The endpoint of the secret to retrieve."); + protected static final RegionUtils regionUtils = new RegionUtils(); protected static final Map, Secret> secretsCache = new ConcurrentHashMap<>(); private static final Pattern SECRETS_ARN_PATTERN = @@ -156,27 +158,21 @@ public AwsSecretsManagerConnectionPlugin(final PluginService pluginService, fina new Object[] {SECRET_ID_PROPERTY.name})); } - String regionString; - if (StringUtils.isNullOrEmpty(props.getProperty(REGION_PROPERTY.name))) { + Region region = regionUtils.getRegion(props, REGION_PROPERTY.name); + if (region == null) { final Matcher matcher = SECRETS_ARN_PATTERN.matcher(secretId); if (matcher.matches()) { - regionString = matcher.group("region"); - } else { - throw new RuntimeException( - Messages.get( - "AwsSecretsManagerConnectionPlugin.missingRequiredConfigParameter", - new Object[] {REGION_PROPERTY.name})); + region = regionUtils.getRegionFromRegionString(matcher.group("region")); } - } else { - regionString = REGION_PROPERTY.getString(props); } - final Region region = Region.of(regionString); - if (!Region.regions().contains(region)) { - throw new RuntimeException(Messages.get( - "AwsSdk.unsupportedRegion", - new Object[] {regionString})); + if (region == null) { + throw new RuntimeException( + Messages.get( + "AwsSecretsManagerConnectionPlugin.missingRequiredConfigParameter", + new Object[] {REGION_PROPERTY.name})); } + this.secretKey = Pair.of(secretId, region); this.secretsManagerClientFunc = secretsManagerClientFunc; diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/customendpoint/CustomEndpointInfo.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/customendpoint/CustomEndpointInfo.java new file mode 100644 index 000000000..b4b80c837 --- /dev/null +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/customendpoint/CustomEndpointInfo.java @@ -0,0 +1,209 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package software.amazon.jdbc.plugin.customendpoint; + +import static software.amazon.jdbc.plugin.customendpoint.MemberListType.EXCLUSION_LIST; +import static software.amazon.jdbc.plugin.customendpoint.MemberListType.STATIC_LIST; + +import java.util.HashSet; +import java.util.List; +import java.util.Objects; +import java.util.Set; +import software.amazon.awssdk.services.rds.model.DBClusterEndpoint; + +/** + * Represents custom endpoint information for a given custom endpoint. + */ +public class CustomEndpointInfo { + private final String endpointIdentifier; // ID portion of the custom endpoint URL. + private final String clusterIdentifier; // ID of the cluster that the custom endpoint belongs to. + private final String url; + private final CustomEndpointRoleType roleType; + + // A given custom endpoint will either specify a static list or an exclusion list, as indicated by `memberListType`. + // If the list is a static list, 'members' specifies instances included in the custom endpoint, and new cluster + // instances will not be automatically added to the custom endpoint. If it is an exclusion list, 'members' specifies + // instances excluded by the custom endpoint, and new cluster instances will be added to the custom endpoint. + private final MemberListType memberListType; + private final Set members; + + /** + * Constructs a new CustomEndpointInfo instance with the specified details. + * + * @param endpointIdentifier The endpoint identifier for the custom endpoint. For example, if the custom endpoint URL + * is "my-custom-endpoint.cluster-custom-XYZ.us-east-1.rds.amazonaws.com", the endpoint + * identifier is "my-custom-endpoint". + * @param clusterIdentifier The cluster identifier for the cluster that the custom endpoint belongs to. + * @param url The URL for the custom endpoint. + * @param roleType The role type of the custom endpoint. + * @param members The instance IDs for the hosts in the custom endpoint. + * @param memberListType The list type for {@code members}. + */ + public CustomEndpointInfo( + String endpointIdentifier, + String clusterIdentifier, + String url, + CustomEndpointRoleType roleType, + Set members, + MemberListType memberListType) { + this.endpointIdentifier = endpointIdentifier; + this.clusterIdentifier = clusterIdentifier; + this.url = url; + this.roleType = roleType; + this.members = members; + this.memberListType = memberListType; + } + + /** + * Constructs a CustomEndpointInfo object from a DBClusterEndpoint instance as returned by the RDS API. + * + * @param responseEndpointInfo The endpoint info returned by the RDS API. + * @return a CustomEndPointInfo object representing the information in the given DBClusterEndpoint. + */ + public static CustomEndpointInfo fromDBClusterEndpoint(DBClusterEndpoint responseEndpointInfo) { + final List members; + final MemberListType memberListType; + + if (responseEndpointInfo.hasStaticMembers()) { + members = responseEndpointInfo.staticMembers(); + memberListType = MemberListType.STATIC_LIST; + } else { + members = responseEndpointInfo.excludedMembers(); + memberListType = MemberListType.EXCLUSION_LIST; + } + + return new CustomEndpointInfo( + responseEndpointInfo.dbClusterEndpointIdentifier(), + responseEndpointInfo.dbClusterIdentifier(), + responseEndpointInfo.endpoint(), + CustomEndpointRoleType.valueOf(responseEndpointInfo.customEndpointType()), + new HashSet<>(members), + memberListType + ); + } + + /** + * Gets the endpoint identifier for the custom endpoint. For example, if the custom endpoint URL is + * "my-custom-endpoint.cluster-custom-XYZ.us-east-1.rds.amazonaws.com", the endpoint identifier is + * "my-custom-endpoint". + * + * @return the endpoint identifier for the custom endpoint. + */ + public String getEndpointIdentifier() { + return endpointIdentifier; + } + + /** + * Gets the cluster identifier for the cluster that the custom endpoint belongs to. + * + * @return the cluster identifier for the cluster that the custom endpoint belongs to. + */ + public String getClusterIdentifier() { + return clusterIdentifier; + } + + /** + * Gets the URL for the custom endpoint. + * + * @return the URL for the custom endpoint. + */ + public String getUrl() { + return url; + } + + /** + * Gets the role type of the custom endpoint. + * + * @return the role type of the custom endpoint. + */ + public CustomEndpointRoleType getCustomEndpointType() { + return roleType; + } + + /** + * Gets the member list type of the custom endpoint. + * + * @return the member list type of the custom endpoint. + */ + public MemberListType getMemberListType() { + return this.memberListType; + } + + /** + * Gets the static members of the custom endpoint. If the custom endpoint member list type is an exclusion list, + * returns null. + * + * @return the static members of the custom endpoint, or null if the custom endpoint member list type is an exclusion + * list. + */ + public Set getStaticMembers() { + return STATIC_LIST.equals(this.memberListType) ? this.members : null; + } + + /** + * Gets the excluded members of the custom endpoint. If the custom endpoint member list type is a static list, + * returns null. + * + * @return the excluded members of the custom endpoint, or null if the custom endpoint member list type is a static + * list. + */ + public Set getExcludedMembers() { + return EXCLUSION_LIST.equals(this.memberListType) ? this.members : null; + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + + if (obj == null) { + return false; + } + + if (getClass() != obj.getClass()) { + return false; + } + + CustomEndpointInfo info = (CustomEndpointInfo) obj; + return Objects.equals(this.endpointIdentifier, info.endpointIdentifier) + && Objects.equals(this.clusterIdentifier, info.clusterIdentifier) + && Objects.equals(this.url, info.url) + && Objects.equals(this.roleType, info.roleType) + && Objects.equals(this.members, info.members) + && Objects.equals(this.memberListType, info.memberListType); + } + + @Override + public int hashCode() { + final int prime = 31; + int result = 1; + result = prime * result + ((this.endpointIdentifier == null) ? 0 : this.endpointIdentifier.hashCode()); + result = prime * result + ((this.clusterIdentifier == null) ? 0 : this.clusterIdentifier.hashCode()); + result = prime * result + ((this.url == null) ? 0 : this.url.hashCode()); + result = prime * result + ((this.roleType == null) ? 0 : this.roleType.hashCode()); + result = prime * result + ((this.memberListType == null) ? 0 : this.memberListType.hashCode()); + return result; + } + + @Override + public String toString() { + return String.format( + "CustomEndpointInfo[url=%s, clusterIdentifier=%s, customEndpointType=%s, memberListType=%s, members=%s]", + this.url, this.clusterIdentifier, this.roleType, this.memberListType, this.members); + } +} diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/customendpoint/CustomEndpointMonitor.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/customendpoint/CustomEndpointMonitor.java new file mode 100644 index 000000000..fd8e82148 --- /dev/null +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/customendpoint/CustomEndpointMonitor.java @@ -0,0 +1,39 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package software.amazon.jdbc.plugin.customendpoint; + +/** + * Interface for custom endpoint monitors. Custom endpoint monitors analyze a given custom endpoint for custom endpoint + * information and future changes to the endpoint. + */ +public interface CustomEndpointMonitor extends AutoCloseable, Runnable { + + /** + * Evaluates whether the monitor should be disposed. + * + * @return true if the monitor should be disposed, otherwise return false. + */ + boolean shouldDispose(); + + /** + * Indicates whether the monitor has info about the custom endpoint or not. This will be false if the monitor is new + * and has not yet had enough time to fetch the info. + * + * @return true if the monitor has info about the custom endpoint, otherwise returns false. + */ + boolean hasCustomEndpointInfo(); +} diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/customendpoint/CustomEndpointMonitorImpl.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/customendpoint/CustomEndpointMonitorImpl.java new file mode 100644 index 000000000..c685d1ca8 --- /dev/null +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/customendpoint/CustomEndpointMonitorImpl.java @@ -0,0 +1,253 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package software.amazon.jdbc.plugin.customendpoint; + +import static software.amazon.jdbc.plugin.customendpoint.MemberListType.STATIC_LIST; + +import java.util.List; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.function.BiFunction; +import java.util.logging.Level; +import java.util.logging.Logger; +import java.util.stream.Collectors; +import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.services.rds.RdsClient; +import software.amazon.awssdk.services.rds.model.DBClusterEndpoint; +import software.amazon.awssdk.services.rds.model.DescribeDbClusterEndpointsResponse; +import software.amazon.awssdk.services.rds.model.Filter; +import software.amazon.jdbc.AllowedAndBlockedHosts; +import software.amazon.jdbc.HostSpec; +import software.amazon.jdbc.PluginService; +import software.amazon.jdbc.util.CacheMap; +import software.amazon.jdbc.util.Messages; +import software.amazon.jdbc.util.RdsUtils; +import software.amazon.jdbc.util.StringUtils; +import software.amazon.jdbc.util.telemetry.TelemetryCounter; +import software.amazon.jdbc.util.telemetry.TelemetryFactory; + +/** + * The default custom endpoint monitor implementation. This class uses a background thread to monitor a given custom + * endpoint for custom endpoint information and future changes to the custom endpoint. + */ +public class CustomEndpointMonitorImpl implements CustomEndpointMonitor { + private static final Logger LOGGER = Logger.getLogger(CustomEndpointPlugin.class.getName()); + private static final String TELEMETRY_ENDPOINT_INFO_CHANGED = "customEndpoint.infoChanged.counter"; + + // Keys are custom endpoint URLs, values are information objects for the associated custom endpoint. + protected static final CacheMap customEndpointInfoCache = new CacheMap<>(); + protected static final long CUSTOM_ENDPOINT_INFO_EXPIRATION_NANO = TimeUnit.MINUTES.toNanos(5); + + protected final AtomicBoolean stop = new AtomicBoolean(false); + protected final RdsUtils rdsUtils = new RdsUtils(); + protected final RdsClient rdsClient; + protected final HostSpec customEndpointHostSpec; + protected final String endpointIdentifier; + protected final Region region; + protected final long refreshRateNano; + + protected final PluginService pluginService; + protected final ExecutorService monitorExecutor = Executors.newSingleThreadExecutor(runnableTarget -> { + final Thread monitoringThread = new Thread(runnableTarget); + monitoringThread.setDaemon(true); + if (!StringUtils.isNullOrEmpty(monitoringThread.getName())) { + monitoringThread.setName(monitoringThread.getName() + "-cem"); + } + return monitoringThread; + }); + + private final TelemetryCounter infoChangedCounter; + + /** + * Constructs a CustomEndpointMonitorImpl instance for the host specified by {@code customEndpointHostSpec}. + * + * @param pluginService The plugin service to use to update the set of allowed/blocked hosts according to + * the custom endpoint info. + * @param customEndpointHostSpec The host information for the custom endpoint to be monitored. + * @param region The region of the custom endpoint to be monitored. + * @param refreshRateNano Controls how often the custom endpoint information should be fetched and analyzed for + * changes. The value specified should be in nanoseconds. + * @param rdsClientFunc The function to call to create the RDS client that will fetch custom endpoint + * information. + */ + public CustomEndpointMonitorImpl( + PluginService pluginService, + HostSpec customEndpointHostSpec, + String endpointIdentifier, + Region region, + long refreshRateNano, + BiFunction rdsClientFunc) { + this.pluginService = pluginService; + this.customEndpointHostSpec = customEndpointHostSpec; + this.endpointIdentifier = endpointIdentifier; + this.region = region; + this.refreshRateNano = refreshRateNano; + this.rdsClient = rdsClientFunc.apply(customEndpointHostSpec, this.region); + + TelemetryFactory telemetryFactory = this.pluginService.getTelemetryFactory(); + this.infoChangedCounter = telemetryFactory.createCounter(TELEMETRY_ENDPOINT_INFO_CHANGED); + + this.monitorExecutor.submit(this); + this.monitorExecutor.shutdown(); + } + + /** + * Analyzes a given custom endpoint for changes to custom endpoint information. + */ + @Override + public void run() { + LOGGER.fine( + Messages.get( + "CustomEndpointMonitorImpl.startingMonitor", + new Object[] { this.customEndpointHostSpec.getHost() })); + + try { + while (!this.stop.get() && !Thread.currentThread().isInterrupted()) { + try { + long start = System.nanoTime(); + + final Filter customEndpointFilter = + Filter.builder().name("db-cluster-endpoint-type").values("custom").build(); + final DescribeDbClusterEndpointsResponse endpointsResponse = + this.rdsClient.describeDBClusterEndpoints( + (builder) -> + builder.dbClusterEndpointIdentifier(this.endpointIdentifier).filters(customEndpointFilter)); + + List endpoints = endpointsResponse.dbClusterEndpoints(); + if (endpoints.size() != 1) { + List endpointURLs = + endpoints.stream().map(DBClusterEndpoint::endpoint).collect(Collectors.toList()); + LOGGER.warning( + Messages.get("CustomEndpointMonitorImpl.unexpectedNumberOfEndpoints", + new Object[] { + this.endpointIdentifier, + this.region.id(), + endpoints.size(), + endpointURLs + } + )); + + TimeUnit.NANOSECONDS.sleep(this.refreshRateNano); + continue; + } + + CustomEndpointInfo endpointInfo = CustomEndpointInfo.fromDBClusterEndpoint(endpoints.get(0)); + CustomEndpointInfo cachedEndpointInfo = customEndpointInfoCache.get(this.customEndpointHostSpec.getHost()); + if (cachedEndpointInfo != null && cachedEndpointInfo.equals(endpointInfo)) { + long elapsedTime = System.nanoTime() - start; + long sleepDuration = Math.min(0, this.refreshRateNano - elapsedTime); + TimeUnit.NANOSECONDS.sleep(sleepDuration); + continue; + } + + LOGGER.fine( + Messages.get( + "CustomEndpointMonitorImpl.detectedChangeInCustomEndpointInfo", + new Object[] {this.customEndpointHostSpec.getHost(), endpointInfo})); + + // The custom endpoint info has changed, so we need to update the set of allowed/blocked hosts. + AllowedAndBlockedHosts allowedAndBlockedHosts; + if (STATIC_LIST.equals(endpointInfo.getMemberListType())) { + allowedAndBlockedHosts = new AllowedAndBlockedHosts(endpointInfo.getStaticMembers(), null); + } else { + allowedAndBlockedHosts = new AllowedAndBlockedHosts(null, endpointInfo.getExcludedMembers()); + } + + this.pluginService.setAllowedAndBlockedHosts(allowedAndBlockedHosts); + customEndpointInfoCache.put( + this.customEndpointHostSpec.getHost(), endpointInfo, CUSTOM_ENDPOINT_INFO_EXPIRATION_NANO); + this.infoChangedCounter.inc(); + + long elapsedTime = System.nanoTime() - start; + long sleepDuration = Math.min(0, this.refreshRateNano - elapsedTime); + TimeUnit.NANOSECONDS.sleep(sleepDuration); + } catch (InterruptedException e) { + throw e; + } catch (Exception e) { + // If the exception is not an InterruptedException, log it and continue monitoring. + LOGGER.log(Level.SEVERE, + Messages.get( + "CustomEndpointMonitorImpl.exception", + new Object[]{this.customEndpointHostSpec.getHost()}), e); + } + } + } catch (InterruptedException e) { + LOGGER.info(Messages.get("CustomEndpointMonitorImpl.interrupted", new Object[]{ this.customEndpointHostSpec })); + Thread.currentThread().interrupt(); + } finally { + customEndpointInfoCache.remove(this.customEndpointHostSpec.getHost()); + this.rdsClient.close(); + LOGGER.fine( + Messages.get("CustomEndpointMonitorImpl.stoppedMonitor", new Object[]{ this.customEndpointHostSpec })); + } + } + + public boolean hasCustomEndpointInfo() { + return customEndpointInfoCache.get(this.customEndpointHostSpec.getHost()) != null; + } + + @Override + public boolean shouldDispose() { + return true; + } + + /** + * Stops the custom endpoint monitor. + */ + @Override + public void close() { + LOGGER.fine( + Messages.get( + "CustomEndpointMonitorImpl.stoppingMonitor", + new Object[]{ this.customEndpointHostSpec.getHost() })); + + this.stop.set(true); + + try { + int terminationTimeoutSec = 5; + if (!this.monitorExecutor.awaitTermination(terminationTimeoutSec, TimeUnit.SECONDS)) { + LOGGER.info( + Messages.get( + "CustomEndpointMonitorImpl.monitorTerminationTimeout", + new Object[]{ terminationTimeoutSec, this.customEndpointHostSpec.getHost() })); + + this.monitorExecutor.shutdownNow(); + } + } catch (InterruptedException e) { + LOGGER.info( + Messages.get( + "CustomEndpointMonitorImpl.interruptedWhileTerminating", + new Object[]{ this.customEndpointHostSpec.getHost() })); + + Thread.currentThread().interrupt(); + this.monitorExecutor.shutdownNow(); + } finally { + customEndpointInfoCache.remove(this.customEndpointHostSpec.getHost()); + this.rdsClient.close(); + } + } + + /** + * Clears the shared custom endpoint information cache. + */ + public static void clearCache() { + LOGGER.info(Messages.get("CustomEndpointMonitorImpl.clearCache")); + customEndpointInfoCache.clear(); + } +} diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/customendpoint/CustomEndpointPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/customendpoint/CustomEndpointPlugin.java new file mode 100644 index 000000000..d35c2819b --- /dev/null +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/customendpoint/CustomEndpointPlugin.java @@ -0,0 +1,318 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package software.amazon.jdbc.plugin.customendpoint; + +import java.sql.Connection; +import java.sql.SQLException; +import java.util.Collections; +import java.util.HashSet; +import java.util.Properties; +import java.util.Set; +import java.util.concurrent.TimeUnit; +import java.util.function.BiFunction; +import java.util.logging.Logger; +import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.services.rds.RdsClient; +import software.amazon.jdbc.AwsWrapperProperty; +import software.amazon.jdbc.HostSpec; +import software.amazon.jdbc.JdbcCallable; +import software.amazon.jdbc.PluginService; +import software.amazon.jdbc.PropertyDefinition; +import software.amazon.jdbc.authentication.AwsCredentialsManager; +import software.amazon.jdbc.plugin.AbstractConnectionPlugin; +import software.amazon.jdbc.util.Messages; +import software.amazon.jdbc.util.RdsUtils; +import software.amazon.jdbc.util.RegionUtils; +import software.amazon.jdbc.util.SlidingExpirationCacheWithCleanupThread; +import software.amazon.jdbc.util.StringUtils; +import software.amazon.jdbc.util.SubscribedMethodHelper; +import software.amazon.jdbc.util.WrapperUtils; +import software.amazon.jdbc.util.telemetry.TelemetryCounter; +import software.amazon.jdbc.util.telemetry.TelemetryFactory; + +/** + * A plugin that analyzes custom endpoints for custom endpoint information and custom endpoint changes, such as adding + * or removing an instance in the custom endpoint. + */ +public class CustomEndpointPlugin extends AbstractConnectionPlugin { + private static final Logger LOGGER = Logger.getLogger(CustomEndpointPlugin.class.getName()); + private static final String TELEMETRY_WAIT_FOR_INFO_COUNTER = "customEndpoint.waitForInfo.counter"; + + protected static final long CACHE_CLEANUP_RATE_NANO = TimeUnit.MINUTES.toNanos(1); + protected static final RegionUtils regionUtils = new RegionUtils(); + protected static final SlidingExpirationCacheWithCleanupThread monitors = + new SlidingExpirationCacheWithCleanupThread<>( + CustomEndpointMonitor::shouldDispose, + (monitor) -> { + try { + monitor.close(); + } catch (Exception ex) { + // ignore + } + }, + CACHE_CLEANUP_RATE_NANO); + + private static final Set subscribedMethods = + Collections.unmodifiableSet(new HashSet() { + { + addAll(SubscribedMethodHelper.NETWORK_BOUND_METHODS); + add("connect"); + } + }); + + public static final AwsWrapperProperty CUSTOM_ENDPOINT_INFO_REFRESH_RATE_MS = new AwsWrapperProperty( + "customEndpointInfoRefreshRateMs", "30000", + "Controls how frequently custom endpoint monitors fetch custom endpoint info."); + + public static final AwsWrapperProperty WAIT_FOR_CUSTOM_ENDPOINT_INFO = new AwsWrapperProperty( + "waitForCustomEndpointInfo", "true", + "Controls whether to wait for custom endpoint info to become available before connecting or executing a " + + "method. Waiting is only necessary if a connection to a given custom endpoint has not been opened or used " + + "recently. Note that disabling this may result in occasional connections to instances outside of the " + + "custom endpoint."); + + public static final AwsWrapperProperty WAIT_FOR_CUSTOM_ENDPOINT_INFO_TIMEOUT_MS = new AwsWrapperProperty( + "waitForCustomEndpointInfoTimeoutMs", "5000", + "Controls the maximum amount of time that the plugin will wait for custom endpoint info to be made " + + "available by the custom endpoint monitor."); + + public static final AwsWrapperProperty CUSTOM_ENDPOINT_MONITOR_IDLE_EXPIRATION_MS = new AwsWrapperProperty( + "customEndpointMonitorExpirationMs", String.valueOf(TimeUnit.MINUTES.toMillis(15)), + "Controls how long a monitor should run without use before expiring and being removed."); + + public static final AwsWrapperProperty REGION_PROPERTY = new AwsWrapperProperty( + "customEndpointRegion", null, + "The region of the cluster's custom endpoints."); + + static { + PropertyDefinition.registerPluginProperties(CustomEndpointPlugin.class); + } + + protected final PluginService pluginService; + protected final Properties props; + protected final RdsUtils rdsUtils = new RdsUtils(); + protected final BiFunction rdsClientFunc; + + protected final TelemetryCounter waitForInfoCounter; + protected final boolean shouldWaitForInfo; + protected final int waitOnCachedInfoDurationMs; + protected final int idleMonitorExpirationMs; + protected HostSpec customEndpointHostSpec; + protected String customEndpointId; + protected Region region; + + /** + * Constructs a new CustomEndpointPlugin instance. + * + * @param pluginService The plugin service that the custom endpoint plugin should use. + * @param props The properties that the custom endpoint plugin should use. + */ + public CustomEndpointPlugin(final PluginService pluginService, final Properties props) { + this( + pluginService, + props, + (hostSpec, region) -> + RdsClient.builder() + .region(region) + .credentialsProvider(AwsCredentialsManager.getProvider(hostSpec, props)) + .build()); + } + + /** + * Constructs a new CustomEndpointPlugin instance. + * + * @param pluginService The plugin service that the custom endpoint plugin should use. + * @param props The properties that the custom endpoint plugin should use. + * @param rdsClientFunc The function to call to obtain an {@link RdsClient} instance. + */ + public CustomEndpointPlugin( + final PluginService pluginService, + final Properties props, + final BiFunction rdsClientFunc) { + this.pluginService = pluginService; + this.props = props; + this.rdsClientFunc = rdsClientFunc; + + this.shouldWaitForInfo = WAIT_FOR_CUSTOM_ENDPOINT_INFO.getBoolean(this.props); + this.waitOnCachedInfoDurationMs = WAIT_FOR_CUSTOM_ENDPOINT_INFO_TIMEOUT_MS.getInteger(this.props); + this.idleMonitorExpirationMs = CUSTOM_ENDPOINT_MONITOR_IDLE_EXPIRATION_MS.getInteger(this.props); + + TelemetryFactory telemetryFactory = pluginService.getTelemetryFactory(); + this.waitForInfoCounter = telemetryFactory.createCounter(TELEMETRY_WAIT_FOR_INFO_COUNTER); + } + + @Override + public Set getSubscribedMethods() { + return subscribedMethods; + } + + @Override + public Connection connect( + final String driverProtocol, + final HostSpec hostSpec, + final Properties props, + final boolean isInitialConnection, + final JdbcCallable connectFunc) + throws SQLException { + if (!this.rdsUtils.isRdsCustomClusterDns(hostSpec.getHost())) { + return connectFunc.call(); + } + + this.customEndpointHostSpec = hostSpec; + LOGGER.finest( + Messages.get( + "CustomEndpointPlugin.connectionRequestToCustomEndpoint", new Object[]{ hostSpec.getHost() })); + + this.customEndpointId = this.rdsUtils.getRdsClusterId(customEndpointHostSpec.getHost()); + if (StringUtils.isNullOrEmpty(customEndpointId)) { + throw new SQLException( + Messages.get( + "CustomEndpointPlugin.errorParsingEndpointIdentifier", + new Object[] {customEndpointHostSpec.getHost()})); + } + + this.region = regionUtils.getRegion(this.customEndpointHostSpec.getHost(), props, REGION_PROPERTY.name); + if (this.region == null) { + throw new SQLException( + Messages.get( + "CustomEndpointPlugin.unableToDetermineRegion", + new Object[] {REGION_PROPERTY.name})); + } + + CustomEndpointMonitor monitor = createMonitorIfAbsent(props); + + if (this.shouldWaitForInfo) { + // If needed, wait a short time for custom endpoint info to be discovered. + waitForCustomEndpointInfo(monitor); + } + + return connectFunc.call(); + } + + /** + * Creates a monitor for the custom endpoint if it does not already exist. + * + * @param props The connection properties. + */ + protected CustomEndpointMonitor createMonitorIfAbsent(Properties props) { + return monitors.computeIfAbsent( + this.customEndpointHostSpec.getHost(), + (customEndpoint) -> new CustomEndpointMonitorImpl( + this.pluginService, + this.customEndpointHostSpec, + this.customEndpointId, + this.region, + TimeUnit.MILLISECONDS.toNanos(CUSTOM_ENDPOINT_INFO_REFRESH_RATE_MS.getLong(props)), + this.rdsClientFunc + ), + TimeUnit.MILLISECONDS.toNanos(this.idleMonitorExpirationMs) + ); + } + + + + /** + * If custom endpoint info does not exist for the current custom endpoint, waits a short time for the info to be + * made available by the custom endpoint monitor. This is necessary so that other plugins can rely on accurate custom + * endpoint info. Since custom endpoint monitors and information are shared, we should not have to wait often. + */ + protected void waitForCustomEndpointInfo(CustomEndpointMonitor monitor) throws SQLException { + boolean hasCustomEndpointInfo = monitor.hasCustomEndpointInfo(); + + if (!hasCustomEndpointInfo) { + // Wait for the monitor to place the custom endpoint info in the cache. This ensures other plugins get accurate + // custom endpoint info. + this.waitForInfoCounter.inc(); + LOGGER.fine( + Messages.get( + "CustomEndpointPlugin.waitingForCustomEndpointInfo", + new Object[]{ this.customEndpointHostSpec.getHost(), this.waitOnCachedInfoDurationMs })); + long waitForEndpointInfoTimeoutNano = + System.nanoTime() + TimeUnit.MILLISECONDS.toNanos(this.waitOnCachedInfoDurationMs); + + try { + while (!hasCustomEndpointInfo && System.nanoTime() < waitForEndpointInfoTimeoutNano) { + TimeUnit.MILLISECONDS.sleep(100); + hasCustomEndpointInfo = monitor.hasCustomEndpointInfo(); + } + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new SQLException( + Messages.get( + "CustomEndpointPlugin.interruptedThread", + new Object[]{ this.customEndpointHostSpec.getHost() })); + } + + if (!hasCustomEndpointInfo) { + throw new SQLException( + Messages.get("CustomEndpointPlugin.timedOutWaitingForCustomEndpointInfo", + new Object[]{this.waitOnCachedInfoDurationMs, this.customEndpointHostSpec.getHost()})); + } + } + } + + /** + * Executes the given method via a pipeline of plugins. If a custom endpoint is being used, a monitor for that custom + * endpoint will be created if it does not already exist. + * + * @param resultClass The class of the object returned by the {@code jdbcMethodFunc}. + * @param exceptionClass The desired exception class for any exceptions that occur while executing the + * {@code jdbcMethodFunc}. + * @param methodInvokeOn The object that the {@code jdbcMethodFunc} is being invoked on. + * @param methodName The name of the method being invoked. + * @param jdbcMethodFunc The execute pipeline to call to invoke the method. + * @param jdbcMethodArgs The arguments to the method being invoked. + * @param The type of the result returned by the method. + * @param The desired type for any exceptions that occur while executing the {@code jdbcMethodFunc}. + * @return The result of the method invocation. + * @throws E If an exception occurs, either directly in this method, or while executing the {@code jdbcMethodFunc}. + */ + @Override + public T execute( + final Class resultClass, + final Class exceptionClass, + final Object methodInvokeOn, + final String methodName, + final JdbcCallable jdbcMethodFunc, + final Object[] jdbcMethodArgs) + throws E { + if (this.customEndpointHostSpec == null) { + return jdbcMethodFunc.call(); + } + + try { + CustomEndpointMonitor monitor = createMonitorIfAbsent(this.props); + if (this.shouldWaitForInfo) { + // If needed, wait a short time for custom endpoint info to be discovered. + waitForCustomEndpointInfo(monitor); + } + } catch (Exception e) { + throw WrapperUtils.wrapExceptionIfNeeded(exceptionClass, e); + } + + return jdbcMethodFunc.call(); + } + + /** + * Closes all active custom endpoint monitors. + */ + public static void closeMonitors() { + LOGGER.info(Messages.get("CustomEndpointPlugin.closeMonitors")); + // The clear call automatically calls close() on all monitors. + monitors.clear(); + } +} diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/customendpoint/CustomEndpointPluginFactory.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/customendpoint/CustomEndpointPluginFactory.java new file mode 100644 index 000000000..6687b9395 --- /dev/null +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/customendpoint/CustomEndpointPluginFactory.java @@ -0,0 +1,37 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + + +package software.amazon.jdbc.plugin.customendpoint; + +import java.util.Properties; +import software.amazon.jdbc.ConnectionPlugin; +import software.amazon.jdbc.ConnectionPluginFactory; +import software.amazon.jdbc.PluginService; +import software.amazon.jdbc.util.Messages; + +public class CustomEndpointPluginFactory implements ConnectionPluginFactory { + @Override + public ConnectionPlugin getInstance(final PluginService pluginService, final Properties props) { + try { + Class.forName("software.amazon.awssdk.services.rds.RdsClient"); + } catch (final ClassNotFoundException e) { + throw new RuntimeException(Messages.get("CustomEndpointPluginFactory.awsSdkNotInClasspath")); + } + + return new CustomEndpointPlugin(pluginService, props); + } +} diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/customendpoint/CustomEndpointRoleType.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/customendpoint/CustomEndpointRoleType.java new file mode 100644 index 000000000..5babc72c6 --- /dev/null +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/customendpoint/CustomEndpointRoleType.java @@ -0,0 +1,26 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package software.amazon.jdbc.plugin.customendpoint; + +/** + * Enum representing the possible roles of instances specified by a custom endpoint. Note that, currently, it is not + * possible to create a WRITER custom endpoint. + */ +public enum CustomEndpointRoleType { + ANY, // Instances in the custom endpoint may be either a writer or a reader. + READER // Instances in the custom endpoint are always readers. +} diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/customendpoint/MemberListType.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/customendpoint/MemberListType.java new file mode 100644 index 000000000..298c28069 --- /dev/null +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/customendpoint/MemberListType.java @@ -0,0 +1,34 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package software.amazon.jdbc.plugin.customendpoint; + +/** + * Enum representing the member list type of a custom endpoint. This information can be used together with a member list + * to determine which instances are included or excluded from a custom endpoint. + */ +public enum MemberListType { + /** + * The member list for the custom endpoint specifies which instances are included in the custom endpoint. If new + * instances are added to the cluster, they will not be automatically added to the custom endpoint. + */ + STATIC_LIST, + /** + * The member list for the custom endpoint specifies which instances are excluded from the custom endpoint. If new + * instances are added to the cluster, they will be automatically added to the custom endpoint. + */ + EXCLUSION_LIST +} diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareReaderFailoverHandler.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareReaderFailoverHandler.java index fab6744d1..cc012d676 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareReaderFailoverHandler.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareReaderFailoverHandler.java @@ -154,7 +154,7 @@ private Future submitInternalFailoverTask( // need to ensure that new connection is a connection to a reader node pluginService.forceRefreshHostList(result.getConnection()); - topology = pluginService.getHosts(); + topology = pluginService.getAllHosts(); for (final HostSpec node : topology) { if (node.getUrl().equals(result.getHost().getUrl())) { // found new connection host in the latest topology diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareWriterFailoverHandler.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareWriterFailoverHandler.java index c00d75794..825efe267 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareWriterFailoverHandler.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/ClusterAwareWriterFailoverHandler.java @@ -259,7 +259,7 @@ public WriterFailoverResult call() { conn = pluginService.forceConnect(this.originalWriterHost, initialConnectionProps); pluginService.forceRefreshHostList(conn); - latestTopology = pluginService.getHosts(); + latestTopology = pluginService.getAllHosts(); } catch (final SQLException exception) { // Propagate exceptions that are not caused by network errors. @@ -405,7 +405,7 @@ private boolean refreshTopologyAndConnectToNewWriter() throws InterruptedExcepti while (true) { try { pluginService.forceRefreshHostList(this.currentReaderConnection); - final List topology = pluginService.getHosts(); + final List topology = pluginService.getAllHosts(); if (!topology.isEmpty()) { @@ -425,7 +425,7 @@ private boolean refreshTopologyAndConnectToNewWriter() throws InterruptedExcepti if (allowOldWriter || !isSame(writerCandidate, this.originalWriterHost)) { // new writer is available, and it's different from the previous writer - LOGGER.finest(() -> Utils.logTopology(this.currentTopology, "[TaskB] ")); + LOGGER.finest(() -> Utils.logTopology(this.currentTopology, "[TaskB] Topology:")); if (connectToWriter(writerCandidate)) { return true; } diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPlugin.java index 672f676cc..bc69df679 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPlugin.java @@ -356,7 +356,7 @@ void setRdsUrlType(final RdsUrlType rdsUrlType) { public boolean isFailoverEnabled() { return this.enableFailoverSetting && !RdsUrlType.RDS_PROXY.equals(this.rdsUrlType) - && !Utils.isNullOrEmpty(this.pluginService.getHosts()); + && !Utils.isNullOrEmpty(this.pluginService.getAllHosts()); } private void initSettings() { @@ -391,7 +391,7 @@ private void invalidInvocationOnClosedConnection() throws SQLException { } private HostSpec getCurrentWriter() throws SQLException { - final List topology = this.pluginService.getHosts(); + final List topology = this.pluginService.getAllHosts(); if (topology == null) { return null; } @@ -489,7 +489,7 @@ private void processFailoverFailure(final String message) throws SQLException { private boolean shouldAttemptReaderConnection() { final List topology = this.pluginService.getHosts(); - if (topology == null || this.failoverMode == FailoverMode.STRICT_WRITER) { + if (Utils.isNullOrEmpty(topology) || this.failoverMode == FailoverMode.STRICT_WRITER) { return false; } @@ -598,7 +598,8 @@ protected void failoverReader(final HostSpec failedHostSpec) throws SQLException if (failedHostSpec != null && failedHostSpec.getRawAvailability() == HostAvailability.AVAILABLE) { failedHost = failedHostSpec; } - final ReaderFailoverResult result = readerFailoverHandler.failover(this.pluginService.getHosts(), failedHost); + final ReaderFailoverResult result = + readerFailoverHandler.failover(this.pluginService.getHosts(), failedHost); if (result != null) { final SQLException exception = result.getException(); @@ -652,13 +653,15 @@ protected void failoverWriter() throws SQLException { try { LOGGER.info(() -> Messages.get("Failover.startWriterFailover")); - final WriterFailoverResult failoverResult = this.writerFailoverHandler.failover(this.pluginService.getHosts()); + final WriterFailoverResult failoverResult = + this.writerFailoverHandler.failover(this.pluginService.getAllHosts()); if (failoverResult != null) { final SQLException exception = failoverResult.getException(); if (exception != null) { throw exception; } } + if (failoverResult == null || !failoverResult.isConnected()) { // "Unable to establish SQL connection to writer node" processFailoverFailure(Messages.get("Failover.unableToConnectToWriter")); @@ -668,6 +671,18 @@ protected void failoverWriter() throws SQLException { // successfully re-connected to a writer node final HostSpec writerHostSpec = getWriter(failoverResult.getTopology()); + final List allowedHosts = this.pluginService.getHosts(); + if (!allowedHosts.contains(writerHostSpec)) { + this.failoverWriterFailedCounter.inc(); + processFailoverFailure( + Messages.get("Failover.newWriterNotAllowed", + new Object[] { + writerHostSpec == null ? "" : writerHostSpec.getHost(), + Utils.logTopology(allowedHosts, "") + })); + return; + } + this.pluginService.setCurrentConnection(failoverResult.getNewConnection(), writerHostSpec); LOGGER.fine( @@ -782,7 +797,6 @@ public Connection connect( private Connection connectInternal(String driverProtocol, HostSpec hostSpec, Properties props, boolean isInitialConnection, JdbcCallable connectFunc, boolean isForceConnect) throws SQLException { - Connection conn = null; try { conn = diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover2/FailoverConnectionPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover2/FailoverConnectionPlugin.java index 1f1f69fda..97ddbce6f 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/failover2/FailoverConnectionPlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/failover2/FailoverConnectionPlugin.java @@ -253,7 +253,7 @@ void initHostProvider( protected boolean isFailoverEnabled() { return !RdsUrlType.RDS_PROXY.equals(this.rdsUrlType) - && !Utils.isNullOrEmpty(this.pluginService.getHosts()); + && !Utils.isNullOrEmpty(this.pluginService.getAllHosts()); } protected void invalidInvocationOnClosedConnection() throws SQLException { @@ -492,7 +492,7 @@ protected void failoverWriter() throws SQLException { throw new FailoverFailedSQLException(Messages.get("Failover.unableToConnectToWriter")); } - final List updatedHosts = this.pluginService.getHosts(); + final List updatedHosts = this.pluginService.getAllHosts(); final Properties copyProp = PropertyUtils.copyProperties(this.properties); copyProp.setProperty(INTERNAL_CONNECT_PROPERTY_NAME, "true"); @@ -502,6 +502,16 @@ protected void failoverWriter() throws SQLException { .findFirst() .orElse(null); + List allowedHosts = this.pluginService.getHosts(); + if (writerCandidate != null && !allowedHosts.contains(writerCandidate)) { + this.failoverWriterFailedCounter.inc(); + LOGGER.severe(Messages.get("Failover.newWriterNotAllowed", + new Object[] {writerCandidate.getHost(), Utils.logTopology(allowedHosts, "")})); + throw new FailoverFailedSQLException( + Messages.get("Failover.newWriterNotAllowed", + new Object[] {writerCandidate.getHost(), Utils.logTopology(allowedHosts, "")})); + } + if (writerCandidate != null) { try { writerCandidateConn = this.pluginService.connect(writerCandidate, copyProp); @@ -630,7 +640,6 @@ public Connection connect( final boolean isInitialConnection, final JdbcCallable connectFunc) throws SQLException { - // This call was initiated by this failover2 plugin and doesn't require any additional processing. if (props.containsKey(INTERNAL_CONNECT_PROPERTY_NAME)) { return connectFunc.call(); diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/federatedauth/FederatedAuthPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/federatedauth/FederatedAuthPlugin.java index 8cf644a76..921f5f892 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/federatedauth/FederatedAuthPlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/federatedauth/FederatedAuthPlugin.java @@ -22,7 +22,6 @@ import java.time.temporal.ChronoUnit; import java.util.Collections; import java.util.HashSet; -import java.util.Optional; import java.util.Properties; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; @@ -31,7 +30,6 @@ import org.checkerframework.checker.nullness.qual.NonNull; import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; import software.amazon.awssdk.regions.Region; -import software.amazon.awssdk.services.rds.RdsUtilities; import software.amazon.jdbc.AwsWrapperProperty; import software.amazon.jdbc.HostSpec; import software.amazon.jdbc.JdbcCallable; @@ -43,12 +41,11 @@ import software.amazon.jdbc.util.IamAuthUtils; import software.amazon.jdbc.util.Messages; import software.amazon.jdbc.util.RdsUtils; +import software.amazon.jdbc.util.RegionUtils; import software.amazon.jdbc.util.StringUtils; -import software.amazon.jdbc.util.telemetry.TelemetryContext; import software.amazon.jdbc.util.telemetry.TelemetryCounter; import software.amazon.jdbc.util.telemetry.TelemetryFactory; import software.amazon.jdbc.util.telemetry.TelemetryGauge; -import software.amazon.jdbc.util.telemetry.TelemetryTraceLevel; public class FederatedAuthPlugin extends AbstractConnectionPlugin { @@ -93,6 +90,7 @@ public class FederatedAuthPlugin extends AbstractConnectionPlugin { new AwsWrapperProperty("dbUser", null, "The database user used to access the database"); protected static final Pattern SAML_RESPONSE_PATTERN = Pattern.compile("SAMLResponse\\W+value=\"(?[^\"]+)\""); protected static final String SAML_RESPONSE_PATTERN_GROUP = "saml"; + protected static final RegionUtils regionUtils = new RegionUtils(); private static final Logger LOGGER = Logger.getLogger(FederatedAuthPlugin.class.getName()); protected final PluginService pluginService; @@ -187,7 +185,11 @@ private Connection connectInternal(final HostSpec hostSpec, final Properties pro hostSpec, this.pluginService.getDialect().getDefaultPort()); - final Region region = IamAuthUtils.getRegion(this.rdsUtils, IAM_REGION.getString(props), host, props); + final Region region = regionUtils.getRegion(host, props, IAM_REGION.name); + if (region == null) { + throw new SQLException( + Messages.get("FederatedAuthPlugin.unableToDetermineRegion", new Object[]{ IAM_REGION.name })); + } final String cacheKey = IamAuthUtils.getCacheKey( DB_USER.getString(props), diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/federatedauth/OktaAuthPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/federatedauth/OktaAuthPlugin.java index afeeaa535..087928626 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/federatedauth/OktaAuthPlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/federatedauth/OktaAuthPlugin.java @@ -22,15 +22,12 @@ import java.time.temporal.ChronoUnit; import java.util.Collections; import java.util.HashSet; -import java.util.Optional; import java.util.Properties; import java.util.Set; import java.util.concurrent.ConcurrentHashMap; import java.util.logging.Logger; -import java.util.regex.Pattern; import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; import software.amazon.awssdk.regions.Region; -import software.amazon.awssdk.services.rds.RdsUtilities; import software.amazon.jdbc.AwsWrapperProperty; import software.amazon.jdbc.HostSpec; import software.amazon.jdbc.JdbcCallable; @@ -42,12 +39,11 @@ import software.amazon.jdbc.util.IamAuthUtils; import software.amazon.jdbc.util.Messages; import software.amazon.jdbc.util.RdsUtils; +import software.amazon.jdbc.util.RegionUtils; import software.amazon.jdbc.util.StringUtils; -import software.amazon.jdbc.util.telemetry.TelemetryContext; import software.amazon.jdbc.util.telemetry.TelemetryCounter; import software.amazon.jdbc.util.telemetry.TelemetryFactory; import software.amazon.jdbc.util.telemetry.TelemetryGauge; -import software.amazon.jdbc.util.telemetry.TelemetryTraceLevel; public class OktaAuthPlugin extends AbstractConnectionPlugin { @@ -89,6 +85,7 @@ public class OktaAuthPlugin extends AbstractConnectionPlugin { new AwsWrapperProperty("dbUser", null, "The database user used to access the database"); private static final Logger LOGGER = Logger.getLogger(OktaAuthPlugin.class.getName()); + protected static final RegionUtils regionUtils = new RegionUtils(); protected final PluginService pluginService; @@ -160,7 +157,11 @@ private Connection connectInternal(final HostSpec hostSpec, final Properties pro hostSpec, this.pluginService.getDialect().getDefaultPort()); - final Region region = IamAuthUtils.getRegion(this.rdsUtils, IAM_REGION.getString(props), host, props); + final Region region = regionUtils.getRegion(host, props, IAM_REGION.name); + if (region == null) { + throw new SQLException( + Messages.get("OktaAuthPlugin.unableToDetermineRegion", new Object[]{ IAM_REGION.name })); + } final String cacheKey = IamAuthUtils.getCacheKey( DB_USER.getString(props), diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/iam/IamAuthConnectionPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/iam/IamAuthConnectionPlugin.java index 6c6eda4b1..29613da32 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/iam/IamAuthConnectionPlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/iam/IamAuthConnectionPlugin.java @@ -39,6 +39,7 @@ import software.amazon.jdbc.util.IamAuthUtils; import software.amazon.jdbc.util.Messages; import software.amazon.jdbc.util.RdsUtils; +import software.amazon.jdbc.util.RegionUtils; import software.amazon.jdbc.util.StringUtils; import software.amazon.jdbc.util.telemetry.TelemetryCounter; import software.amazon.jdbc.util.telemetry.TelemetryFactory; @@ -73,6 +74,7 @@ public class IamAuthConnectionPlugin extends AbstractConnectionPlugin { "iamExpiration", String.valueOf(DEFAULT_TOKEN_EXPIRATION_SEC), "IAM token cache expiration in seconds"); + protected static final RegionUtils regionUtils = new RegionUtils(); protected final PluginService pluginService; protected final RdsUtils rdsUtils = new RdsUtils(); @@ -127,8 +129,11 @@ private Connection connectInternal(String driverProtocol, HostSpec hostSpec, Pro hostSpec, this.pluginService.getDialect().getDefaultPort()); - final String iamRegion = IAM_REGION.getString(props); - final Region region = IamAuthUtils.getRegion(rdsUtils, iamRegion, host, props); + final Region region = regionUtils.getRegion(host, props, IAM_REGION.name); + if (region == null) { + throw new SQLException( + Messages.get("IamAuthConnectionPlugin.unableToDetermineRegion", new Object[]{ IAM_REGION.name })); + } final int tokenExpirationSec = IAM_EXPIRATION.getInteger(props); diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/limitless/LimitlessRouterMonitor.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/limitless/LimitlessRouterMonitor.java index 0fead0ce5..d8c0d7cd2 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/limitless/LimitlessRouterMonitor.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/limitless/LimitlessRouterMonitor.java @@ -152,7 +152,7 @@ public void run() { newLimitlessRouters, LimitlessRouterServiceImpl.MONITOR_DISPOSAL_TIME_MS.getLong(props)); RoundRobinHostSelector.setRoundRobinHostWeightPairsProperty(this.props, newLimitlessRouters); - LOGGER.finest(Utils.logTopology(newLimitlessRouters, "[limitlessRouterMonitor]")); + LOGGER.finest(Utils.logTopology(newLimitlessRouters, "[limitlessRouterMonitor] Topology:")); TimeUnit.MILLISECONDS.sleep(this.intervalMs); // do not include this in the telemetry } catch (final InterruptedException exception) { LOGGER.finest( diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/readwritesplitting/ReadWriteSplittingPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/readwritesplitting/ReadWriteSplittingPlugin.java index 857436430..eba4649b8 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/readwritesplitting/ReadWriteSplittingPlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/readwritesplitting/ReadWriteSplittingPlugin.java @@ -42,6 +42,7 @@ import software.amazon.jdbc.plugin.failover.FailoverSQLException; import software.amazon.jdbc.util.Messages; import software.amazon.jdbc.util.SqlState; +import software.amazon.jdbc.util.Utils; import software.amazon.jdbc.util.WrapperUtils; public class ReadWriteSplittingPlugin extends AbstractConnectionPlugin @@ -315,7 +316,7 @@ void switchConnectionIfRequired(final boolean readOnly) throws SQLException { } final List hosts = this.pluginService.getHosts(); - if (hosts == null || hosts.isEmpty()) { + if (Utils.isNullOrEmpty(hosts)) { logAndThrowException(Messages.get("ReadWriteSplittingPlugin.emptyHostList")); } @@ -386,8 +387,8 @@ private void switchToWriterConnection( return; } - this.inReadWriteSplit = true; final HostSpec writerHost = getWriter(hosts); + this.inReadWriteSplit = true; if (!isConnectionUsable(this.writerConnection)) { getNewWriterConnection(writerHost); } else { @@ -426,6 +427,12 @@ private void switchToReaderConnection(final List hosts) return; } + if (this.readerHostSpec != null && !hosts.contains(this.readerHostSpec)) { + // The old reader cannot be used anymore because it is no longer in the list of allowed hosts. + this.readerConnection = null; + this.readerHostSpec = null; + } + this.inReadWriteSplit = true; if (!isConnectionUsable(this.readerConnection)) { initializeReaderConnection(hosts); diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/staledns/AuroraStaleDnsHelper.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/staledns/AuroraStaleDnsHelper.java index cfa5767ba..55e4bf656 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/staledns/AuroraStaleDnsHelper.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/staledns/AuroraStaleDnsHelper.java @@ -95,7 +95,7 @@ public Connection getVerifiedConnection( this.pluginService.refreshHostList(conn); } - LOGGER.finest(() -> Utils.logTopology(this.pluginService.getHosts())); + LOGGER.finest(() -> Utils.logTopology(this.pluginService.getAllHosts())); if (this.writerHostSpec == null) { final HostSpec writerCandidate = this.getWriter(); @@ -135,6 +135,15 @@ public Connection getVerifiedConnection( new Object[]{this.writerHostSpec})); staleDNSDetectedCounter.inc(); + if (!this.pluginService.getHosts().contains(this.writerHostSpec)) { + throw new SQLException( + Messages.get("AuroraStaleDnsHelper.currentWriterNotAllowed", + new Object[] { + this.writerHostSpec == null ? "" : this.writerHostSpec.getHost(), + Utils.logTopology(this.pluginService.getHosts(), "")}) + ); + } + final Connection writerConn = this.pluginService.connect(this.writerHostSpec, props); if (isInitialConnection) { hostListProviderService.setInitialConnectionHostSpec(this.writerHostSpec); @@ -170,7 +179,7 @@ public void notifyNodeListChanged(final Map> } private HostSpec getWriter() { - for (final HostSpec host : this.pluginService.getHosts()) { + for (final HostSpec host : this.pluginService.getAllHosts()) { if (host.getRole() == HostRole.WRITER) { return host; } diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/strategy/fastestresponse/FastestResponseStrategyPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/strategy/fastestresponse/FastestResponseStrategyPlugin.java index e94537439..d12a8df35 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/strategy/fastestresponse/FastestResponseStrategyPlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/strategy/fastestresponse/FastestResponseStrategyPlugin.java @@ -154,7 +154,7 @@ public HostSpec getHostSpecByStrategy(final HostRole role, final String strategy final HostSpec fastestResponseHost = cachedFastestResponseHostByRole.get(role.name()); if (fastestResponseHost != null) { - // Found a fastest host. Let find it in the the latest topology. + // Found a fastest host. Let find it in the latest topology. HostSpec foundHostSpec = this.pluginService.getHosts().stream() .filter(x -> x.equals(fastestResponseHost)) .findAny() diff --git a/wrapper/src/main/java/software/amazon/jdbc/util/IamAuthUtils.java b/wrapper/src/main/java/software/amazon/jdbc/util/IamAuthUtils.java index fe308b490..0a07773f1 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/util/IamAuthUtils.java +++ b/wrapper/src/main/java/software/amazon/jdbc/util/IamAuthUtils.java @@ -16,9 +16,6 @@ package software.amazon.jdbc.util; -import java.sql.SQLException; -import java.util.Optional; -import java.util.Properties; import java.util.logging.Logger; import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; import software.amazon.awssdk.regions.Region; @@ -53,46 +50,6 @@ public static int getIamPort(final int iamDefaultPort, final HostSpec hostSpec, } } - public static Region getRegion( - final RdsUtils rdsUtils, - final String iamRegion, - final String hostname, - final Properties props) throws SQLException { - if (!StringUtils.isNullOrEmpty(iamRegion)) { - return Region.of(iamRegion); - } - - // Fallback to using host - // Get Region - final String rdsRegion = rdsUtils.getRdsRegion(hostname); - - if (StringUtils.isNullOrEmpty(rdsRegion)) { - // Does not match Amazon's Hostname, throw exception - final String exceptionMessage = Messages.get( - "Authentication.unsupportedHostname", - new Object[] {hostname}); - - LOGGER.fine(exceptionMessage); - throw new SQLException(exceptionMessage); - } - - // Check Region - final Optional regionOptional = Region.regions().stream() - .filter(r -> r.id().equalsIgnoreCase(rdsRegion)) - .findFirst(); - - if (!regionOptional.isPresent()) { - final String exceptionMessage = Messages.get( - "AwsSdk.unsupportedRegion", - new Object[] {rdsRegion}); - - LOGGER.fine(exceptionMessage); - throw new SQLException(exceptionMessage); - } - - return regionOptional.get(); - } - public static String getCacheKey( final String user, final String hostname, diff --git a/wrapper/src/main/java/software/amazon/jdbc/util/RdsUtils.java b/wrapper/src/main/java/software/amazon/jdbc/util/RdsUtils.java index c9ce393ae..516bfe8be 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/util/RdsUtils.java +++ b/wrapper/src/main/java/software/amazon/jdbc/util/RdsUtils.java @@ -202,6 +202,20 @@ public boolean isRdsProxyDns(final String host) { return dnsGroup != null && dnsGroup.startsWith("proxy-"); } + public @Nullable String getRdsClusterId(final String host) { + if (StringUtils.isNullOrEmpty(host)) { + return null; + } + + final Matcher matcher = cacheMatcher(host, + AURORA_DNS_PATTERN, AURORA_CHINA_DNS_PATTERN, AURORA_OLD_CHINA_DNS_PATTERN, AURORA_GOV_DNS_PATTERN); + if (getRegexGroup(matcher, DNS_GROUP) != null) { + return getRegexGroup(matcher, INSTANCE_GROUP); + } + + return null; + } + public @Nullable String getRdsInstanceId(final String host) { if (StringUtils.isNullOrEmpty(host)) { return null; diff --git a/wrapper/src/main/java/software/amazon/jdbc/util/RegionUtils.java b/wrapper/src/main/java/software/amazon/jdbc/util/RegionUtils.java new file mode 100644 index 000000000..b13aceb90 --- /dev/null +++ b/wrapper/src/main/java/software/amazon/jdbc/util/RegionUtils.java @@ -0,0 +1,91 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package software.amazon.jdbc.util; + +import java.util.Properties; +import org.checkerframework.checker.nullness.qual.Nullable; +import software.amazon.awssdk.regions.Region; + +public class RegionUtils { + protected static final RdsUtils rdsUtils = new RdsUtils(); + + /** + * Determines the AWS region from the given parameters. If the region is defined in the properties, that region will + * be used. Otherwise, attempts to determine the region from the passed in host. + * + * @param host The host from which to extract the region if it is not defined in the properties. + * @param props The connection properties for the connection being established. + * @param propKey The key name of the region property. + * @return The AWS region defined by the properties or extracted from the host, or null if the region was not + * defined in the properties and could not be determined from the {@code host}. + */ + @Nullable + public Region getRegion(String host, Properties props, String propKey) { + Region region = getRegion(props, propKey); + return region != null ? region : getRegionFromHost(host); + } + + /** + * Determines the AWS region from the given properties. + * + * @param props The connection properties for the connection being established. + * @param propKey The key name of the region property. + * @return The AWS region defined by the properties, or null if the region was not defined in the properties. + */ + @Nullable + public Region getRegion(Properties props, String propKey) { + String regionString = props.getProperty(propKey); + if (StringUtils.isNullOrEmpty(regionString)) { + return null; + } + + return getRegionFromRegionString(regionString); + } + + /** + * Determines the AWS region from the given region string. + * + * @param regionString The connection properties for the connection being established. + * @return The AWS region of the given region string. + */ + public Region getRegionFromRegionString(String regionString) { + final Region region = Region.of(regionString); + if (!Region.regions().contains(region)) { + throw new RuntimeException( + Messages.get( + "AwsSdk.unsupportedRegion", + new Object[] {regionString})); + } + + return region; + } + + /** + * Determines the AWS region from the given host string. + * + * @param host The host from which to extract the region. + * @return The AWS region used in the host string. + */ + public Region getRegionFromHost(String host) { + String regionString = rdsUtils.getRdsRegion(host); + if (StringUtils.isNullOrEmpty(regionString)) { + return null; + } + + return getRegionFromRegionString(regionString); + } +} diff --git a/wrapper/src/main/java/software/amazon/jdbc/util/Utils.java b/wrapper/src/main/java/software/amazon/jdbc/util/Utils.java index 3fa03b3b4..f54c32769 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/util/Utils.java +++ b/wrapper/src/main/java/software/amazon/jdbc/util/Utils.java @@ -16,14 +16,14 @@ package software.amazon.jdbc.util; +import java.util.Collection; import java.util.List; -import org.checkerframework.checker.nullness.qual.NonNull; import org.checkerframework.checker.nullness.qual.Nullable; import software.amazon.jdbc.HostSpec; public class Utils { - public static boolean isNullOrEmpty(final List list) { - return list == null || list.isEmpty(); + public static boolean isNullOrEmpty(final Collection c) { + return c == null || c.isEmpty(); } public static String logTopology(final @Nullable List hosts) { @@ -45,7 +45,8 @@ public static String logTopology( msg.append(" ").append(host == null ? "" : host); } } - return (messagePrefix == null ? "" : messagePrefix) - + Messages.get("Utils.topology", new Object[] {msg.toString()}); + + return Messages.get("Utils.topology", + new Object[] {messagePrefix == null ? "Topology:" : messagePrefix, msg.toString()}); } } diff --git a/wrapper/src/main/resources/aws_advanced_jdbc_wrapper_messages.properties b/wrapper/src/main/resources/aws_advanced_jdbc_wrapper_messages.properties index 40bf87093..422ab9972 100644 --- a/wrapper/src/main/resources/aws_advanced_jdbc_wrapper_messages.properties +++ b/wrapper/src/main/resources/aws_advanced_jdbc_wrapper_messages.properties @@ -121,6 +121,31 @@ ConsoleConsumer.unexpectedOutputType=Unexpected outputType: ''{0}''. CredentialsProviderFactory.failedToInitializeHttpClient=Failed to initialize HttpClient. CredentialsProviderFactory.unsupportedIdp=Unsupported Identity Provider ''{0}''. Please visit to the documentation for supported Identity Providers. +# Custom Endpoint Monitor Impl +CustomEndpointMonitorImpl.clearCache=Clearing info in the custom endpoint monitor info cache. +CustomEndpointMonitorImpl.detectedChangeInCustomEndpointInfo=Detected change in custom endpoint info for ''{0}'':\n{1} +CustomEndpointMonitorImpl.exception=Encountered an exception while monitoring custom endpoint ''{0}''. +CustomEndpointMonitorImpl.interrupted=Custom endpoint monitor for ''{0}'' was interrupted. +CustomEndpointMonitorImpl.interruptedWhileTerminating=Interrupted while awaiting termination of custom endpoint monitor for ''{0}''. The monitor will be forcefully shut down. +CustomEndpointMonitorImpl.monitorTerminationTimeout=Timed out after waiting {0} seconds for custom endpoint monitor for ''{1}'' to terminate gracefully. The monitor will be forcefully shut down. +CustomEndpointMonitorImpl.startingMonitor=Starting custom endpoint monitor for ''{0}''. +CustomEndpointMonitorImpl.stoppedMonitor=Stopped custom endpoint monitor for ''{0}''. +CustomEndpointMonitorImpl.stoppingMonitor=Stopping custom endpoint monitor for ''{0}''. +CustomEndpointMonitorImpl.unexpectedNumberOfEndpoints=Unexpected number of custom endpoints with endpoint identifier ''{0}'' in region ''{1}''. Expected 1, but found {2}. Endpoints:\n{3}. + +# Custom Endpoint Plugin +CustomEndpointPlugin.timedOutWaitingForCustomEndpointInfo=The custom endpoint plugin timed out after {0}ms while waiting for custom endpoint info for host ''{1}''. +CustomEndpointPlugin.closeMonitors=Closing custom endpoint monitors. Active custom endpoint monitors will be stopped, closed, and removed from the monitors cache. +CustomEndpointPlugin.connectionRequestToCustomEndpoint=Detected a connection request to a custom endpoint URL: ''{0}''. +CustomEndpointPlugin.errorParsingEndpointIdentifier=Unable to parse custom endpoint identifier from URL: ''{0}''. +CustomEndpointPlugin.foundInfoInCache=Done waiting for custom endpoint info for ''{0}'':\n{1} +CustomEndpointPlugin.interruptedThread=The custom endpoint plugin was interrupted while waiting for custom endpoint info for host ''{0}''. +CustomEndpointPlugin.unableToDetermineRegion=Unable to determine connection region. If you are using a non-standard RDS URL, please set the ''{0}'' property. +CustomEndpointPlugin.waitingForCustomEndpointInfo=Custom endpoint info for ''{0}'' was not found. Waiting {1}ms for the endpoint monitor to fetch info... + +# Custom Endpoint Plugin Factory +CustomEndpointPluginFactory.awsSdkNotInClasspath=Required dependency 'AWS Java SDK RDS v2.x' is not on the classpath. + # Data Cache Connection Plugin DataCacheConnectionPlugin.queryResultsCached=[{0}] Query results will be cached: {1} @@ -156,10 +181,14 @@ Failover.establishedConnection=Connected to: {0} Failover.startWriterFailover=Starting writer failover procedure. Failover.startReaderFailover=Starting reader failover procedure. Failover.invalidNode=Node is no longer available in the topology: {0} +Failover.newWriterNotAllowed=The failover process identified the new writer but the host is not in the list of allowed hosts. New writer host: ''{0}''. Allowed hosts: {1} Failover.noOperationsAfterConnectionClosed=No operations allowed after connection closed. Failover.readerFailoverElapsed=Reader failover elapsed in {0} ms. Failover.writerFailoverElapsed=Writer failover elapsed in {0} ms. +# Federated Auth Plugin +FederatedAuthPlugin.unableToDetermineRegion=Unable to determine connection region. If you are using a non-standard RDS URL, please set the ''{0}'' property. + # HikariPooledConnectionProvider HikariPooledConnectionProvider.errorConnectingWithDataSource=Unable to connect to ''{0}'' using the Hikari data source. HikariPooledConnectionProvider.errorConnectingWithDataSourceWithCause=Unable to connect to ''{0}'' using the Hikari data source. Exception message: ''{1}'' @@ -183,7 +212,7 @@ HostSelector.roundRobinInvalidDefaultWeight=The provided default weight value is # IAM Auth Connection Plugin IamAuthConnectionPlugin.unhandledException=Unhandled exception: ''{0}'' IamAuthConnectionPlugin.connectException=Error occurred while opening a connection: ''{0}'' -IamAuthConnectionPlugin.missingRequiredConfigParameter=Configuration parameter ''{0}'' is required. +IamAuthConnectionPlugin.unableToDetermineRegion=Unable to determine connection region. If you are using a non-standard RDS URL, please set the ''{0}'' property. # Limitless Connection Plugin LimitlessConnectionPlugin.connectWithHost=Connecting to host {0}. @@ -242,6 +271,7 @@ MonitorImpl.stopMonitoringThread=Stop monitoring thread for {0}. MonitorServiceImpl.emptyAliasSet=Empty alias set passed for ''{0}''. Set should not be empty. MonitorServiceImpl.errorPopulatingAliases=Error occurred while populating aliases: ''{0}''. +OktaAuthPlugin.unableToDetermineRegion=Unable to determine connection region. If you are using a non-standard RDS URL, please set the ''{0}'' property. OktaAuthPlugin.requiredDependenciesMissing=OktaAuthPlugin requires the 'AWS Java SDK for AWS Secret Token Service' and 'JSoup' dependencies. Both of these dependencies must be registered on the classpath. OktaCredentialsProviderFactory.sessionTokenRequestFailed=Failed to retrieve session token from Okta, please ensure the provided Okta username, password and endpoint are correct. OktaCredentialsProviderFactory.invalidSessionToken=Invalid response from session token request to Okta. @@ -250,6 +280,7 @@ OktaCredentialsProviderFactory.invalidSamlResponse=The SAML Assertion request di OktaCredentialsProviderFactory.samlRequestFailed=Okta SAML Assertion request failed with HTTP status ''{0}'', reason phrase ''{1}'', and response ''{2}'' # Plugin Service Impl +PluginServiceImpl.currentHostNotAllowed=The current host is not in the list of allowed hosts. Current host: ''{0}''. Allowed hosts: {1} PluginServiceImpl.hostListEmpty=Current host list is empty. PluginServiceImpl.releaseResources=Releasing resources. PluginServiceImpl.hostListException=Exception while getting a host list. @@ -278,7 +309,7 @@ ReadWriteSplittingPlugin.fallbackToWriter=Failed to switch to a reader. {0}. The ReadWriteSplittingPlugin.switchedFromWriterToReader=Switched from a writer to a reader host. New reader host: ''{0}'' ReadWriteSplittingPlugin.switchedFromReaderToWriter=Switched from a reader to a writer host. New writer host: ''{0}'' ReadWriteSplittingPlugin.settingCurrentConnection=Setting the current connection to ''{0}'' -ReadWriteSplittingPlugin.noWriterFound=No writer was found in the current host list. +ReadWriteSplittingPlugin.noWriterFound=No writer was found in the current host list. This may occur if the writer is not in the list of allowed hosts. ReadWriteSplittingPlugin.noReadersFound=A reader instance was requested via setReadOnly, but there are no readers in the host list. The current writer will be used as a fallback: ''{0}'' ReadWriteSplittingPlugin.emptyHostList=Host list is empty. ReadWriteSplittingPlugin.exceptionWhileExecutingCommand=Detected an exception while executing a command: ''{0}'' @@ -301,6 +332,7 @@ WrapperUtils.failedToInitializeClass=Can''t initialize class ''{0}''. # Aurora Stale DNS AuroraStaleDnsPlugin.requireDynamicProvider=Dynamic host list provider is required. AuroraStaleDnsHelper.clusterEndpointDns=Cluster endpoint resolves to {0}. +AuroraStaleDnsHelper.currentWriterNotAllowed=The current writer is not in the list of allowed hosts. Current host: ''{0}''. Allowed hosts: {1} AuroraStaleDnsHelper.writerHostSpec=Writer host: {0} AuroraStaleDnsHelper.writerInetAddress=Writer host address: {0} AuroraStaleDnsHelper.staleDnsDetected=Stale DNS data detected. Opening a connection to ''{0}''. @@ -311,7 +343,7 @@ OpenedConnectionTracker.unableToPopulateOpenedConnectionQueue=The driver is unab OpenedConnectionTracker.invalidatingConnections=Invalidating opened connections to host: ''{0}'' # Util -Utils.topology=Topology: \n{0} +Utils.topology={0} \n{1} # Dialect Manager DialectManager.unknownDialectCode=Unknown dialect code: ''{0}''. diff --git a/wrapper/src/test/java/integration/container/TestDriverProvider.java b/wrapper/src/test/java/integration/container/TestDriverProvider.java index afd434347..502d65331 100644 --- a/wrapper/src/test/java/integration/container/TestDriverProvider.java +++ b/wrapper/src/test/java/integration/container/TestDriverProvider.java @@ -58,6 +58,8 @@ import software.amazon.jdbc.HikariPooledConnectionProvider; import software.amazon.jdbc.dialect.DialectManager; import software.amazon.jdbc.hostlistprovider.monitoring.MonitoringRdsHostListProvider; +import software.amazon.jdbc.plugin.customendpoint.CustomEndpointMonitorImpl; +import software.amazon.jdbc.plugin.customendpoint.CustomEndpointPlugin; import software.amazon.jdbc.plugin.efm.MonitorThreadContainer; import software.amazon.jdbc.plugin.efm2.MonitorServiceImpl; import software.amazon.jdbc.targetdriverdialect.TargetDriverDialectManager; @@ -234,6 +236,8 @@ private static void clearCaches() { software.amazon.jdbc.plugin.efm2.MonitorServiceImpl.clearCache(); HikariPooledConnectionProvider.clearCache(); MonitoringRdsHostListProvider.clearCache(); + CustomEndpointPlugin.closeMonitors(); + CustomEndpointMonitorImpl.clearCache(); } private static void checkClusterHealth(final boolean makeSureFirstInstanceWriter) diff --git a/wrapper/src/test/java/integration/container/tests/AutoscalingTests.java b/wrapper/src/test/java/integration/container/tests/AutoscalingTests.java index 64cb2ecb9..5a4c61979 100644 --- a/wrapper/src/test/java/integration/container/tests/AutoscalingTests.java +++ b/wrapper/src/test/java/integration/container/tests/AutoscalingTests.java @@ -61,7 +61,7 @@ @EnableOnDatabaseEngineDeployment({DatabaseEngineDeployment.AURORA}) @EnableOnNumOfInstances(min = 5) @MakeSureFirstInstanceWriter -@Order(16) +@Order(17) public class AutoscalingTests { protected static final AuroraTestUtility auroraUtil = AuroraTestUtility.getUtility(); diff --git a/wrapper/src/test/java/integration/container/tests/CustomEndpointTest.java b/wrapper/src/test/java/integration/container/tests/CustomEndpointTest.java new file mode 100644 index 000000000..caf5aef19 --- /dev/null +++ b/wrapper/src/test/java/integration/container/tests/CustomEndpointTest.java @@ -0,0 +1,359 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package integration.container.tests; + +import static org.junit.jupiter.api.Assertions.assertDoesNotThrow; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; + +import integration.DatabaseEngineDeployment; +import integration.TestDatabaseInfo; +import integration.TestEnvironmentFeatures; +import integration.TestEnvironmentInfo; +import integration.TestInstanceInfo; +import integration.container.ConnectionStringHelper; +import integration.container.TestDriverProvider; +import integration.container.TestEnvironment; +import integration.container.condition.DisableOnTestFeature; +import integration.container.condition.EnableOnDatabaseEngineDeployment; +import integration.container.condition.EnableOnNumOfInstances; +import integration.container.condition.MakeSureFirstInstanceWriter; +import integration.util.AuroraTestUtility; +import java.sql.Connection; +import java.sql.DriverManager; +import java.sql.SQLException; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Properties; +import java.util.Set; +import java.util.concurrent.TimeUnit; +import java.util.logging.Logger; +import java.util.stream.Collectors; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.MethodOrderer; +import org.junit.jupiter.api.Order; +import org.junit.jupiter.api.TestMethodOrder; +import org.junit.jupiter.api.TestTemplate; +import org.junit.jupiter.api.extension.ExtendWith; +import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.services.rds.RdsClient; +import software.amazon.awssdk.services.rds.model.DBClusterEndpoint; +import software.amazon.awssdk.services.rds.model.DbClusterEndpointNotFoundException; +import software.amazon.awssdk.services.rds.model.DescribeDbClusterEndpointsResponse; +import software.amazon.awssdk.services.rds.model.Filter; +import software.amazon.jdbc.PropertyDefinition; +import software.amazon.jdbc.plugin.failover.FailoverSuccessSQLException; +import software.amazon.jdbc.plugin.readwritesplitting.ReadWriteSplittingSQLException; + +@TestMethodOrder(MethodOrderer.MethodName.class) +@ExtendWith(TestDriverProvider.class) +@EnableOnDatabaseEngineDeployment({DatabaseEngineDeployment.AURORA, DatabaseEngineDeployment.RDS_MULTI_AZ_CLUSTER}) +@DisableOnTestFeature({ + TestEnvironmentFeatures.PERFORMANCE, + TestEnvironmentFeatures.RUN_HIBERNATE_TESTS_ONLY, + TestEnvironmentFeatures.RUN_AUTOSCALING_TESTS_ONLY}) +@EnableOnNumOfInstances(min = 3) +@MakeSureFirstInstanceWriter +@Order(16) +public class CustomEndpointTest { + private static final Logger LOGGER = Logger.getLogger(CustomEndpointTest.class.getName()); + protected static final String oneInstanceEndpointId = "test-endpoint-1"; + protected static final String twoInstanceEndpointId = "test-endpoint-2"; + protected static final Map endpoints = new HashMap() {{ + put(oneInstanceEndpointId, null); + put(twoInstanceEndpointId, null); + }}; + + protected static final AuroraTestUtility auroraUtil = AuroraTestUtility.getUtility(); + protected static final boolean reuseExistingEndpoints = false; + + protected String currentWriter; + + @BeforeAll + public static void createEndpoints() { + TestEnvironmentInfo envInfo = TestEnvironment.getCurrent().getInfo(); + String clusterId = envInfo.getAuroraClusterName(); + String region = envInfo.getRegion(); + + try (RdsClient client = RdsClient.builder().region(Region.of(region)).build()) { + if (reuseExistingEndpoints) { + waitUntilEndpointsAvailable(client, clusterId); + return; + } + + // Delete pre-existing custom endpoints in case they weren't cleaned up in a previous run. + deleteEndpoints(client); + + List instances = envInfo.getDatabaseInfo().getInstances(); + createEndpoint(client, clusterId, oneInstanceEndpointId, instances.subList(0, 1)); + createEndpoint(client, clusterId, twoInstanceEndpointId, instances.subList(0, 2)); + waitUntilEndpointsAvailable(client, clusterId); + } + } + + private static void deleteEndpoints(RdsClient client) { + for (String endpointId : endpoints.keySet()) { + try { + client.deleteDBClusterEndpoint((builder) -> builder.dbClusterEndpointIdentifier(endpointId)); + } catch (DbClusterEndpointNotFoundException e) { + // Custom endpoint already does not exist - do nothing. + } + } + + waitUntilEndpointsDeleted(client); + } + + private static void waitUntilEndpointsDeleted(RdsClient client) { + String clusterId = TestEnvironment.getCurrent().getInfo().getAuroraClusterName(); + long deleteTimeoutNano = System.nanoTime() + TimeUnit.MINUTES.toNanos(5); + boolean allEndpointsDeleted = false; + + while (!allEndpointsDeleted && System.nanoTime() < deleteTimeoutNano) { + Filter customEndpointFilter = + Filter.builder().name("db-cluster-endpoint-type").values("custom").build(); + DescribeDbClusterEndpointsResponse endpointsResponse = client.describeDBClusterEndpoints( + (builder) -> + builder.dbClusterIdentifier(clusterId).filters(customEndpointFilter)); + List responseIDs = endpointsResponse.dbClusterEndpoints().stream() + .map(DBClusterEndpoint::dbClusterEndpointIdentifier).collect(Collectors.toList()); + + allEndpointsDeleted = endpoints.keySet().stream().noneMatch(responseIDs::contains); + } + + if (!allEndpointsDeleted) { + throw new RuntimeException( + "The test setup step timed out while attempting to delete pre-existing test custom endpoints."); + } + } + + private static void createEndpoint( + RdsClient client, String clusterId, String endpointId, List instances) { + List instanceIDs = instances.stream().map(TestInstanceInfo::getInstanceId).collect(Collectors.toList()); + client.createDBClusterEndpoint((builder) -> + builder.dbClusterEndpointIdentifier(endpointId) + .dbClusterIdentifier(clusterId) + .endpointType("ANY") + .staticMembers(instanceIDs)); + } + + public static void waitUntilEndpointsAvailable(RdsClient client, String clusterId) { + long timeoutEndNano = System.nanoTime() + TimeUnit.MINUTES.toNanos(5); + boolean allEndpointsAvailable = false; + + while (!allEndpointsAvailable && System.nanoTime() < timeoutEndNano) { + Filter customEndpointFilter = + Filter.builder().name("db-cluster-endpoint-type").values("custom").build(); + DescribeDbClusterEndpointsResponse endpointsResponse = client.describeDBClusterEndpoints( + (builder) -> + builder.dbClusterIdentifier(clusterId).filters(customEndpointFilter)); + List responseEndpoints = endpointsResponse.dbClusterEndpoints(); + + int numAvailableEndpoints = 0; + for (int i = 0; i < responseEndpoints.size() && numAvailableEndpoints < endpoints.size(); i++) { + DBClusterEndpoint endpoint = responseEndpoints.get(i); + String endpointId = endpoint.dbClusterEndpointIdentifier(); + if (endpoints.containsKey(endpointId)) { + endpoints.put(endpointId, endpoint); + if ("available".equals(endpoint.status())) { + numAvailableEndpoints++; + } + } + } + + allEndpointsAvailable = numAvailableEndpoints == endpoints.size(); + } + + if (!allEndpointsAvailable) { + throw new RuntimeException( + "The test setup step timed out while waiting for the new custom endpoints to become available."); + } + } + + public static void waitUntilEndpointHasCorrectState(RdsClient client, String endpointId, List membersList) { + long start = System.nanoTime(); + + // Convert to set for later comparison. + Set members = new HashSet<>(membersList); + long timeoutEndNano = System.nanoTime() + TimeUnit.MINUTES.toNanos(20); + boolean hasCorrectState = false; + while (!hasCorrectState && System.nanoTime() < timeoutEndNano) { + DescribeDbClusterEndpointsResponse response = client.describeDBClusterEndpoints( + (builder) -> + builder.dbClusterEndpointIdentifier(endpointId)); + if (response.dbClusterEndpoints().size() != 1) { + fail("Unexpected number of endpoints returned while waiting for custom endpoint to have the specified list of " + + "members. Expected 1, got " + response.dbClusterEndpoints().size()); + } + + DBClusterEndpoint endpoint = response.dbClusterEndpoints().get(0); + // Compare sets to ignore order when checking for members equality. + Set responseMembers = new HashSet<>(endpoint.staticMembers()); + hasCorrectState = responseMembers.equals(members) && "available".equals(endpoint.status()); + } + + if (!hasCorrectState) { + fail("Timed out while waiting for the custom endpoint to stabilize"); + } + + LOGGER.fine("waitUntilEndpointHasCorrectState took " + + TimeUnit.NANOSECONDS.toSeconds(System.nanoTime() - start) + " seconds"); + } + + @BeforeEach + public void identifyWriter() { + this.currentWriter = + TestEnvironment.getCurrent().getInfo().getDatabaseInfo().getInstances().get(0) + .getInstanceId(); + } + + @AfterAll + public static void cleanup() { + if (reuseExistingEndpoints) { + return; + } + + String region = TestEnvironment.getCurrent().getInfo().getRegion(); + try (RdsClient client = RdsClient.builder().region(Region.of(region)).build()) { + deleteEndpoints(client); + } + } + + protected Properties initDefaultProps() { + final Properties props = ConnectionStringHelper.getDefaultProperties(); + props.setProperty(PropertyDefinition.PLUGINS.name, "customEndpoint,readWriteSplitting,failover"); + PropertyDefinition.CONNECT_TIMEOUT.set(props, "10000"); + PropertyDefinition.SOCKET_TIMEOUT.set(props, "10000"); + return props; + } + + @TestTemplate + public void testCustomEndpointFailover() throws SQLException, InterruptedException { + // The single-instance endpoint will be used for this test. + final DBClusterEndpoint endpoint = endpoints.get(oneInstanceEndpointId); + final TestDatabaseInfo dbInfo = TestEnvironment.getCurrent().getInfo().getDatabaseInfo(); + final int port = dbInfo.getClusterEndpointPort(); + final Properties props = initDefaultProps(); + props.setProperty("failoverMode", "reader-or-writer"); + + try (final Connection conn = DriverManager.getConnection( + ConnectionStringHelper.getWrapperUrl(endpoint.endpoint(), port, dbInfo.getDefaultDbName()), + props)) { + List endpointMembers = endpoint.staticMembers(); + String instanceId = auroraUtil.queryInstanceId(conn); + assertTrue(endpointMembers.contains(instanceId)); + + // Use failover API to break connection. + if (instanceId.equals(this.currentWriter)) { + auroraUtil.failoverClusterAndWaitUntilWriterChanged(); + } else { + auroraUtil.failoverClusterToATargetAndWaitUntilWriterChanged(this.currentWriter, instanceId); + } + + assertThrows(FailoverSuccessSQLException.class, () -> auroraUtil.queryInstanceId(conn)); + + String newInstanceId = auroraUtil.queryInstanceId(conn); + assertTrue(endpointMembers.contains(newInstanceId)); + } + } + + @TestTemplate + public void testCustomEndpointReadWriteSplitting_withCustomEndpointChanges() throws SQLException { + // The one-instance custom endpoint will be used for this test. + final DBClusterEndpoint testEndpoint = endpoints.get(oneInstanceEndpointId); + TestEnvironmentInfo envInfo = TestEnvironment.getCurrent().getInfo(); + final TestDatabaseInfo dbInfo = envInfo.getDatabaseInfo(); + final int port = dbInfo.getClusterEndpointPort(); + final Properties props = initDefaultProps(); + // This setting is not required for the test, but it allows us to also test re-creation of expired monitors since it + // takes more than 30 seconds to modify the cluster endpoint (usually around 140s). + props.setProperty("customEndpointMonitorExpirationMs", "30000"); + + try (final Connection conn = + DriverManager.getConnection( + ConnectionStringHelper.getWrapperUrl(testEndpoint.endpoint(), port, dbInfo.getDefaultDbName()), + props); + final RdsClient client = RdsClient.builder().region(Region.of(envInfo.getRegion())).build()) { + List endpointMembers = testEndpoint.staticMembers(); + String instanceId1 = auroraUtil.queryInstanceId(conn); + assertTrue(endpointMembers.contains(instanceId1)); + + // Attempt to switch to an instance of the opposite role. This should fail since the custom endpoint consists only + // of the current host. + boolean newReadOnlyValue = currentWriter.equals(instanceId1); + if (newReadOnlyValue) { + // We are connected to the writer. Attempting to switch to the reader will not work but will intentionally not + // throw an exception. In this scenario we log a warning and purposefully stick with the writer. + LOGGER.fine("Initial connection is to the writer. Attempting to switch to reader..."); + conn.setReadOnly(newReadOnlyValue); + String newInstanceId = auroraUtil.queryInstanceId(conn); + assertEquals(instanceId1, newInstanceId); + } else { + // We are connected to the reader. Attempting to switch to the writer will throw an exception. + LOGGER.fine("Initial connection is to a reader. Attempting to switch to writer..."); + assertThrows(ReadWriteSplittingSQLException.class, () -> conn.setReadOnly(newReadOnlyValue)); + } + + String newMember; + if (currentWriter.equals(instanceId1)) { + newMember = dbInfo.getInstances().get(1).getInstanceId(); + } else { + newMember = currentWriter; + } + + client.modifyDBClusterEndpoint( + builder -> + builder.dbClusterEndpointIdentifier(oneInstanceEndpointId).staticMembers(instanceId1, newMember)); + + try { + waitUntilEndpointHasCorrectState(client, oneInstanceEndpointId, Arrays.asList(instanceId1, newMember)); + + // We should now be able to switch to newMember. + assertDoesNotThrow(() -> conn.setReadOnly(newReadOnlyValue)); + String instanceId2 = auroraUtil.queryInstanceId(conn); + assertEquals(instanceId2, newMember); + + // Switch back to original instance. + conn.setReadOnly(!newReadOnlyValue); + } finally { + client.modifyDBClusterEndpoint( + builder -> + builder.dbClusterEndpointIdentifier(oneInstanceEndpointId).staticMembers(instanceId1)); + waitUntilEndpointHasCorrectState(client, oneInstanceEndpointId, Collections.singletonList(instanceId1)); + } + + // We should not be able to switch again because newMember was removed from the custom endpoint. + if (newReadOnlyValue) { + // We are connected to the writer. Attempting to switch to the reader will not work but will intentionally not + // throw an exception. In this scenario we log a warning and purposefully stick with the writer. + conn.setReadOnly(newReadOnlyValue); + String newInstanceId = auroraUtil.queryInstanceId(conn); + assertEquals(instanceId1, newInstanceId); + } else { + // We are connected to the reader. Attempting to switch to the writer will throw an exception. + assertThrows(ReadWriteSplittingSQLException.class, () -> conn.setReadOnly(newReadOnlyValue)); + } + } + } +} diff --git a/wrapper/src/test/java/software/amazon/jdbc/PluginServiceImplTests.java b/wrapper/src/test/java/software/amazon/jdbc/PluginServiceImplTests.java index 2fca72b60..d9af2a07f 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/PluginServiceImplTests.java +++ b/wrapper/src/test/java/software/amazon/jdbc/PluginServiceImplTests.java @@ -389,13 +389,13 @@ public void testSetNodeListAdded() throws SQLException { mockTargetDriverDialect, configurationProfile, sessionStateService)); - target.hosts = new ArrayList<>(); + target.allHosts = new ArrayList<>(); target.hostListProvider = hostListProvider; target.refreshHostList(); - assertEquals(1, target.getHosts().size()); - assertEquals("hostA", target.getHosts().get(0).getHost()); + assertEquals(1, target.getAllHosts().size()); + assertEquals("hostA", target.getAllHosts().get(0).getHost()); verify(pluginManager, times(1)).notifyNodeListChanged(any()); Map> notifiedChanges = argumentChangesMap.getValue(); @@ -423,15 +423,15 @@ public void testSetNodeListDeleted() throws SQLException { mockTargetDriverDialect, configurationProfile, sessionStateService)); - target.hosts = Arrays.asList( + target.allHosts = Arrays.asList( new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("hostA").build(), new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("hostB").build()); target.hostListProvider = hostListProvider; target.refreshHostList(); - assertEquals(1, target.getHosts().size()); - assertEquals("hostB", target.getHosts().get(0).getHost()); + assertEquals(1, target.getAllHosts().size()); + assertEquals("hostB", target.getAllHosts().get(0).getHost()); verify(pluginManager, times(1)).notifyNodeListChanged(any()); Map> notifiedChanges = argumentChangesMap.getValue(); @@ -460,14 +460,14 @@ public void testSetNodeListChanged() throws SQLException { mockTargetDriverDialect, configurationProfile, sessionStateService)); - target.hosts = Collections.singletonList(new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) + target.allHosts = Collections.singletonList(new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) .host("hostA").port(HostSpec.NO_PORT).role(HostRole.WRITER).build()); target.hostListProvider = hostListProvider; target.refreshHostList(); - assertEquals(1, target.getHosts().size()); - assertEquals("hostA", target.getHosts().get(0).getHost()); + assertEquals(1, target.getAllHosts().size()); + assertEquals("hostA", target.getAllHosts().get(0).getHost()); verify(pluginManager, times(1)).notifyNodeListChanged(any()); Map> notifiedChanges = argumentChangesMap.getValue(); @@ -497,14 +497,14 @@ public void testSetNodeListNoChanges() throws SQLException { mockTargetDriverDialect, configurationProfile, sessionStateService)); - target.hosts = Collections.singletonList(new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) + target.allHosts = Collections.singletonList(new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) .host("hostA").port(HostSpec.NO_PORT).role(HostRole.READER).build()); target.hostListProvider = hostListProvider; target.refreshHostList(); - assertEquals(1, target.getHosts().size()); - assertEquals("hostA", target.getHosts().get(0).getHost()); + assertEquals(1, target.getAllHosts().size()); + assertEquals("hostA", target.getAllHosts().get(0).getHost()); verify(pluginManager, times(0)).notifyNodeListChanged(any()); } @@ -523,7 +523,7 @@ public void testNodeAvailabilityNotChanged() throws SQLException { mockTargetDriverDialect, configurationProfile, sessionStateService)); - target.hosts = Collections.singletonList( + target.allHosts = Collections.singletonList( new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) .host("hostA").port(HostSpec.NO_PORT).role(HostRole.READER).availability(HostAvailability.AVAILABLE) .build()); @@ -532,8 +532,8 @@ public void testNodeAvailabilityNotChanged() throws SQLException { aliases.add("hostA"); target.setAvailability(aliases, HostAvailability.AVAILABLE); - assertEquals(1, target.getHosts().size()); - assertEquals(HostAvailability.AVAILABLE, target.getHosts().get(0).getAvailability()); + assertEquals(1, target.getAllHosts().size()); + assertEquals(HostAvailability.AVAILABLE, target.getAllHosts().get(0).getAvailability()); verify(pluginManager, never()).notifyNodeListChanged(any()); } @@ -552,7 +552,7 @@ public void testNodeAvailabilityChanged_WentDown() throws SQLException { mockTargetDriverDialect, configurationProfile, sessionStateService)); - target.hosts = Collections.singletonList( + target.allHosts = Collections.singletonList( new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) .host("hostA").port(HostSpec.NO_PORT).role(HostRole.READER).availability(HostAvailability.AVAILABLE) .build()); @@ -561,8 +561,8 @@ public void testNodeAvailabilityChanged_WentDown() throws SQLException { aliases.add("hostA"); target.setAvailability(aliases, HostAvailability.NOT_AVAILABLE); - assertEquals(1, target.getHosts().size()); - assertEquals(HostAvailability.NOT_AVAILABLE, target.getHosts().get(0).getAvailability()); + assertEquals(1, target.getAllHosts().size()); + assertEquals(HostAvailability.NOT_AVAILABLE, target.getAllHosts().get(0).getAvailability()); verify(pluginManager, times(1)).notifyNodeListChanged(any()); Map> notifiedChanges = argumentChangesMap.getValue(); @@ -588,7 +588,7 @@ public void testNodeAvailabilityChanged_WentUp() throws SQLException { mockTargetDriverDialect, configurationProfile, sessionStateService)); - target.hosts = Collections.singletonList( + target.allHosts = Collections.singletonList( new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) .host("hostA").port(HostSpec.NO_PORT).role(HostRole.READER).availability(HostAvailability.NOT_AVAILABLE) .build()); @@ -597,8 +597,8 @@ public void testNodeAvailabilityChanged_WentUp() throws SQLException { aliases.add("hostA"); target.setAvailability(aliases, HostAvailability.AVAILABLE); - assertEquals(1, target.getHosts().size()); - assertEquals(HostAvailability.AVAILABLE, target.getHosts().get(0).getAvailability()); + assertEquals(1, target.getAllHosts().size()); + assertEquals(HostAvailability.AVAILABLE, target.getAllHosts().get(0).getAvailability()); verify(pluginManager, times(1)).notifyNodeListChanged(any()); Map> notifiedChanges = argumentChangesMap.getValue(); @@ -636,7 +636,7 @@ public void testNodeAvailabilityChanged_WentUp_ByAlias() throws SQLException { configurationProfile, sessionStateService)); - target.hosts = Arrays.asList(hostA, hostB); + target.allHosts = Arrays.asList(hostA, hostB); Set aliases = new HashSet<>(); aliases.add("hostA.custom.domain.com"); @@ -681,7 +681,7 @@ public void testNodeAvailabilityChanged_WentUp_MultipleHostsByAlias() throws SQL configurationProfile, sessionStateService)); - target.hosts = Arrays.asList(hostA, hostB); + target.allHosts = Arrays.asList(hostA, hostB); Set aliases = new HashSet<>(); aliases.add("ip-10-10-10-10"); diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/AuroraConnectionTrackerPluginTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/AuroraConnectionTrackerPluginTest.java index 426ab6b0d..13fbfec40 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/plugin/AuroraConnectionTrackerPluginTest.java +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/AuroraConnectionTrackerPluginTest.java @@ -131,7 +131,7 @@ public void testInvalidateOpenedConnectionsWhenWriterHostNotChange() throws SQLE .build(); // Host list changes during simulated failover - when(mockPluginService.getHosts()).thenReturn(Collections.singletonList(originalHost)); + when(mockPluginService.getAllHosts()).thenReturn(Collections.singletonList(originalHost)); doThrow(expectedException).when(mockSqlFunction).call(); final AuroraConnectionTrackerPlugin plugin = new AuroraConnectionTrackerPlugin( @@ -161,7 +161,7 @@ public void testInvalidateOpenedConnectionsWhenWriterHostChanged() throws SQLExc .build(); final HostSpec failoverTargetHost = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("host2") .build(); - when(mockPluginService.getHosts()) + when(mockPluginService.getAllHosts()) .thenReturn(Collections.singletonList(originalHost)) .thenReturn(Collections.singletonList(failoverTargetHost)); when(mockSqlFunction.call()) diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/customendpoint/CustomEndpointMonitorImplTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/customendpoint/CustomEndpointMonitorImplTest.java new file mode 100644 index 000000000..668f92e8e --- /dev/null +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/customendpoint/CustomEndpointMonitorImplTest.java @@ -0,0 +1,134 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package software.amazon.jdbc.plugin.customendpoint; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.atLeastOnce; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.sql.SQLException; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashSet; +import java.util.List; +import java.util.Set; +import java.util.concurrent.TimeUnit; +import java.util.function.BiFunction; +import java.util.function.Consumer; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.services.rds.RdsClient; +import software.amazon.awssdk.services.rds.model.DBClusterEndpoint; +import software.amazon.awssdk.services.rds.model.DescribeDbClusterEndpointsResponse; +import software.amazon.jdbc.AllowedAndBlockedHosts; +import software.amazon.jdbc.HostSpec; +import software.amazon.jdbc.HostSpecBuilder; +import software.amazon.jdbc.PluginService; +import software.amazon.jdbc.hostavailability.HostAvailabilityStrategy; +import software.amazon.jdbc.hostavailability.SimpleHostAvailabilityStrategy; +import software.amazon.jdbc.util.telemetry.TelemetryCounter; +import software.amazon.jdbc.util.telemetry.TelemetryFactory; + +public class CustomEndpointMonitorImplTest { + @Mock private PluginService mockPluginService; + @Mock private BiFunction mockRdsClientFunc; + @Mock private RdsClient mockRdsClient; + @Mock private DescribeDbClusterEndpointsResponse mockDescribeResponse; + @Mock private DBClusterEndpoint mockClusterEndpoint1; + @Mock private DBClusterEndpoint mockClusterEndpoint2; + @Mock private TelemetryFactory mockTelemetryFactory; + @Mock private TelemetryCounter mockTelemetryCounter; + + private final String customEndpointUrl1 = "custom1.cluster-custom-XYZ.us-east-1.rds.amazonaws.com"; + private final String customEndpointUrl2 = "custom2.cluster-custom-XYZ.us-east-1.rds.amazonaws.com"; + private final String endpointId = "custom1"; + private final String clusterId = "cluster1"; + private final String endpointRoleType = "ANY"; + private List twoEndpointList; + private List oneEndpointList; + private final List staticMembersList = Arrays.asList("member1", "member2"); + private final Set staticMembersSet = new HashSet<>(staticMembersList); + private final CustomEndpointInfo expectedInfo = new CustomEndpointInfo( + endpointId, + clusterId, + customEndpointUrl1, + CustomEndpointRoleType.valueOf(endpointRoleType), + staticMembersSet, + MemberListType.STATIC_LIST); + + private AutoCloseable closeable; + private final HostAvailabilityStrategy availabilityStrategy = new SimpleHostAvailabilityStrategy(); + private final HostSpecBuilder hostSpecBuilder = new HostSpecBuilder(availabilityStrategy); + private final HostSpec host = hostSpecBuilder.host(customEndpointUrl1).build(); + + @BeforeEach + public void init() throws SQLException { + closeable = MockitoAnnotations.openMocks(this); + + twoEndpointList = Arrays.asList(mockClusterEndpoint1, mockClusterEndpoint2); + oneEndpointList = Collections.singletonList(mockClusterEndpoint1); + + when(mockPluginService.getTelemetryFactory()).thenReturn(mockTelemetryFactory); + when(mockTelemetryFactory.createCounter(any(String.class))).thenReturn(mockTelemetryCounter); + when(mockRdsClientFunc.apply(any(HostSpec.class), any(Region.class))).thenReturn(mockRdsClient); + when(mockRdsClient.describeDBClusterEndpoints(any(Consumer.class))).thenReturn(mockDescribeResponse); + when(mockDescribeResponse.dbClusterEndpoints()).thenReturn(twoEndpointList).thenReturn(oneEndpointList); + when(mockClusterEndpoint1.endpoint()).thenReturn(customEndpointUrl1); + when(mockClusterEndpoint2.endpoint()).thenReturn(customEndpointUrl2); + when(mockClusterEndpoint1.hasStaticMembers()).thenReturn(true); + when(mockClusterEndpoint1.staticMembers()).thenReturn(staticMembersList); + when(mockClusterEndpoint1.dbClusterEndpointIdentifier()).thenReturn(endpointId); + when(mockClusterEndpoint1.dbClusterIdentifier()).thenReturn(clusterId); + when(mockClusterEndpoint1.customEndpointType()).thenReturn(endpointRoleType); + } + + @AfterEach + void cleanUp() throws Exception { + closeable.close(); + CustomEndpointPlugin.monitors.clear(); + } + + @Test + public void testRun() throws InterruptedException { + CustomEndpointMonitorImpl monitor = new CustomEndpointMonitorImpl( + mockPluginService, host, endpointId, Region.US_EAST_1, TimeUnit.MILLISECONDS.toNanos(50), mockRdsClientFunc); + // Wait for 2 run cycles. The first will return an unexpected number of endpoints in the API response, the second + // will return the expected number of endpoints (one). + TimeUnit.MILLISECONDS.sleep(100); + assertEquals(expectedInfo, CustomEndpointMonitorImpl.customEndpointInfoCache.get(host.getHost())); + monitor.close(); + + ArgumentCaptor captor = ArgumentCaptor.forClass(AllowedAndBlockedHosts.class); + verify(mockPluginService).setAllowedAndBlockedHosts(captor.capture()); + assertEquals(staticMembersSet, captor.getValue().getAllowedHostIds()); + assertNull(captor.getValue().getBlockedHostIds()); + + // Wait for monitor to close + TimeUnit.MILLISECONDS.sleep(50); + assertTrue(monitor.stop.get()); + verify(mockRdsClient, atLeastOnce()).close(); + } +} diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/customendpoint/CustomEndpointPluginTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/customendpoint/CustomEndpointPluginTest.java new file mode 100644 index 000000000..e6723d35b --- /dev/null +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/customendpoint/CustomEndpointPluginTest.java @@ -0,0 +1,160 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package software.amazon.jdbc.plugin.customendpoint; + +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.atLeastOnce; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static software.amazon.jdbc.plugin.customendpoint.CustomEndpointPlugin.WAIT_FOR_CUSTOM_ENDPOINT_INFO_TIMEOUT_MS; + +import java.sql.Connection; +import java.sql.SQLException; +import java.sql.Statement; +import java.util.Properties; +import java.util.concurrent.TimeUnit; +import java.util.function.BiFunction; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.services.rds.RdsClient; +import software.amazon.jdbc.HostSpec; +import software.amazon.jdbc.HostSpecBuilder; +import software.amazon.jdbc.JdbcCallable; +import software.amazon.jdbc.PluginService; +import software.amazon.jdbc.hostavailability.HostAvailabilityStrategy; +import software.amazon.jdbc.hostavailability.SimpleHostAvailabilityStrategy; +import software.amazon.jdbc.util.telemetry.TelemetryCounter; +import software.amazon.jdbc.util.telemetry.TelemetryFactory; + +public class CustomEndpointPluginTest { + private final String writerClusterUrl = "writer.cluster-XYZ.us-east-1.rds.amazonaws.com"; + private final String customEndpointUrl = "custom.cluster-custom-XYZ.us-east-1.rds.amazonaws.com"; + + private AutoCloseable closeable; + private final Properties props = new Properties(); + private final HostAvailabilityStrategy availabilityStrategy = new SimpleHostAvailabilityStrategy(); + private final HostSpecBuilder hostSpecBuilder = new HostSpecBuilder(availabilityStrategy); + private final HostSpec writerClusterHost = hostSpecBuilder.host(writerClusterUrl).build(); + private final HostSpec host = hostSpecBuilder.host(customEndpointUrl).build(); + + @Mock private PluginService mockPluginService; + @Mock private BiFunction mockRdsClientFunc; + @Mock private TelemetryFactory mockTelemetryFactory; + @Mock private TelemetryCounter mockTelemetryCounter; + @Mock private JdbcCallable mockConnectFunc; + @Mock private JdbcCallable mockJdbcMethodFunc; + @Mock private Connection mockConnection; + @Mock private CustomEndpointMonitor mockMonitor; + + @BeforeEach + public void init() throws SQLException { + closeable = MockitoAnnotations.openMocks(this); + + when(mockPluginService.getTelemetryFactory()).thenReturn(mockTelemetryFactory); + when(mockTelemetryFactory.createCounter(any(String.class))).thenReturn(mockTelemetryCounter); + when(mockMonitor.hasCustomEndpointInfo()).thenReturn(true); + } + + @AfterEach + void cleanUp() throws Exception { + closeable.close(); + props.clear(); + CustomEndpointPlugin.monitors.clear(); + } + + private CustomEndpointPlugin getSpyPlugin() { + CustomEndpointPlugin plugin = new CustomEndpointPlugin(mockPluginService, props, mockRdsClientFunc); + CustomEndpointPlugin spyPlugin = spy(plugin); + doReturn(mockMonitor).when(spyPlugin).createMonitorIfAbsent(any(Properties.class)); + return spyPlugin; + } + + @Test + public void testConnect_monitorNotCreatedIfNotCustomEndpointHost() throws SQLException { + CustomEndpointPlugin spyPlugin = getSpyPlugin(); + + spyPlugin.connect("", writerClusterHost, props, true, mockConnectFunc); + + verify(mockConnectFunc, times(1)).call(); + verify(spyPlugin, never()).createMonitorIfAbsent(any(Properties.class)); + } + + @Test + public void testConnect_monitorCreated() throws SQLException { + CustomEndpointPlugin spyPlugin = getSpyPlugin(); + + spyPlugin.connect("", host, props, true, mockConnectFunc); + + verify(spyPlugin, times(1)).createMonitorIfAbsent(eq(props)); + verify(mockConnectFunc, times(1)).call(); + } + + @Test + public void testConnect_timeoutWaitingForInfo() throws SQLException { + WAIT_FOR_CUSTOM_ENDPOINT_INFO_TIMEOUT_MS.set(props, "1"); + CustomEndpointPlugin spyPlugin = getSpyPlugin(); + when(mockMonitor.hasCustomEndpointInfo()).thenReturn(false); + + assertThrows(SQLException.class, () -> spyPlugin.connect("", host, props, true, mockConnectFunc)); + + verify(spyPlugin, times(1)).createMonitorIfAbsent(eq(props)); + verify(mockConnectFunc, never()).call(); + } + + @Test + public void testExecute_monitorNotCreatedIfNotCustomEndpointHost() throws SQLException { + CustomEndpointPlugin spyPlugin = getSpyPlugin(); + + spyPlugin.execute( + Statement.class, SQLException.class, mockConnection, "Connection.createStatement", mockJdbcMethodFunc, null); + + verify(mockJdbcMethodFunc, times(1)).call(); + verify(spyPlugin, never()).createMonitorIfAbsent(any(Properties.class)); + } + + @Test + public void testExecute_monitorCreated() throws SQLException { + CustomEndpointPlugin spyPlugin = getSpyPlugin(); + spyPlugin.customEndpointHostSpec = host; + + spyPlugin.execute( + Statement.class, SQLException.class, mockConnection, "Connection.createStatement", mockJdbcMethodFunc, null); + + verify(spyPlugin, times(1)).createMonitorIfAbsent(eq(props)); + verify(mockJdbcMethodFunc, times(1)).call(); + } + + @Test + public void testCloseMonitors() throws Exception { + CustomEndpointPlugin.monitors.computeIfAbsent("test-monitor", (key) -> mockMonitor, TimeUnit.SECONDS.toNanos(30)); + + CustomEndpointPlugin.closeMonitors(); + + // close() may be called by the cleanup thread in addition to the call below. + verify(mockMonitor, atLeastOnce()).close(); + } +} diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/efm/ConcurrencyTests.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/efm/ConcurrencyTests.java index ee8bc0806..1685f326b 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/plugin/efm/ConcurrencyTests.java +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/efm/ConcurrencyTests.java @@ -57,6 +57,7 @@ import org.checkerframework.checker.nullness.qual.Nullable; import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; +import software.amazon.jdbc.AllowedAndBlockedHosts; import software.amazon.jdbc.ConnectionPlugin; import software.amazon.jdbc.ConnectionProvider; import software.amazon.jdbc.HostListProvider; @@ -480,6 +481,11 @@ public EnumSet setCurrentConnection(@NonNull Connection conne return null; } + @Override + public List getAllHosts() { + return null; + } + @Override public List getHosts() { return null; @@ -490,6 +496,10 @@ public HostSpec getInitialConnectionHostSpec() { return null; } + @Override + public void setAllowedAndBlockedHosts(AllowedAndBlockedHosts allowedAndBlockedHosts) { + } + @Override public boolean acceptsStrategy(HostRole role, String strategy) { return false; diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/failover/ClusterAwareWriterFailoverHandlerTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/failover/ClusterAwareWriterFailoverHandlerTest.java index 5742de669..515325f7c 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/plugin/failover/ClusterAwareWriterFailoverHandlerTest.java +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/failover/ClusterAwareWriterFailoverHandlerTest.java @@ -95,7 +95,7 @@ public void testReconnectToWriter_taskBReaderException() throws SQLException { when(mockPluginService.forceConnect(refEq(readerA), eq(properties))).thenThrow(SQLException.class); when(mockPluginService.forceConnect(refEq(readerB), eq(properties))).thenThrow(SQLException.class); - when(mockPluginService.getHosts()).thenReturn(topology); + when(mockPluginService.getAllHosts()).thenReturn(topology); when(mockReaderFailover.getReaderConnection(ArgumentMatchers.anyList())).thenThrow(SQLException.class); @@ -133,7 +133,7 @@ public void testReconnectToWriter_SlowReaderA() throws SQLException { when(mockPluginService.forceConnect(refEq(writer), eq(properties))).thenReturn(mockWriterConnection); when(mockPluginService.forceConnect(refEq(readerB), eq(properties))).thenThrow(SQLException.class); when(mockPluginService.forceConnect(refEq(newWriterHost), eq(properties))).thenReturn(mockNewWriterConnection); - when(mockPluginService.getHosts()).thenReturn(topology).thenReturn(newTopology); + when(mockPluginService.getAllHosts()).thenReturn(topology).thenReturn(newTopology); when(mockReaderFailover.getReaderConnection(ArgumentMatchers.anyList())) .thenAnswer( @@ -183,7 +183,7 @@ public void testReconnectToWriter_taskBDefers() throws SQLException { }); when(mockPluginService.forceConnect(refEq(readerB), eq(properties))).thenThrow(SQLException.class); - when(mockPluginService.getHosts()).thenReturn(topology); + when(mockPluginService.getAllHosts()).thenReturn(topology); when(mockReaderFailover.getReaderConnection(ArgumentMatchers.anyList())) .thenReturn(new ReaderFailoverResult(mockReaderAConnection, readerA, true)); @@ -231,7 +231,7 @@ public void testConnectToReaderA_SlowWriter() throws SQLException { when(mockPluginService.forceConnect(refEq(readerB), eq(properties))).thenReturn(mockReaderBConnection); when(mockPluginService.forceConnect(refEq(newWriterHost), eq(properties))).thenReturn(mockNewWriterConnection); - when(mockPluginService.getHosts()).thenReturn(newTopology); + when(mockPluginService.getAllHosts()).thenReturn(newTopology); when(mockReaderFailover.getReaderConnection(ArgumentMatchers.anyList())) .thenReturn(new ReaderFailoverResult(mockReaderAConnection, readerA, true)); @@ -280,7 +280,7 @@ public void testConnectToReaderA_taskADefers() throws SQLException { }); final List newTopology = Arrays.asList(newWriterHost, writer, readerA, readerB); - when(mockPluginService.getHosts()).thenReturn(newTopology); + when(mockPluginService.getAllHosts()).thenReturn(newTopology); when(mockReaderFailover.getReaderConnection(ArgumentMatchers.anyList())) .thenReturn(new ReaderFailoverResult(mockReaderAConnection, readerA, true)); @@ -334,7 +334,7 @@ public void testFailedToConnect_failoverTimeout() throws SQLException { Thread.sleep(30000); return mockNewWriterConnection; }); - when(mockPluginService.getHosts()).thenReturn(newTopology); + when(mockPluginService.getAllHosts()).thenReturn(newTopology); when(mockReaderFailover.getReaderConnection(ArgumentMatchers.anyList())) .thenReturn(new ReaderFailoverResult(mockReaderAConnection, readerA, true)); @@ -381,7 +381,7 @@ public void testFailedToConnect_taskAException_taskBWriterException() throws SQL when(mockPluginService.forceConnect(refEq(newWriterHost), eq(properties))).thenThrow(exception); when(mockPluginService.isNetworkException(exception)).thenReturn(true); - when(mockPluginService.getHosts()).thenReturn(newTopology); + when(mockPluginService.getAllHosts()).thenReturn(newTopology); when(mockReaderFailover.getReaderConnection(ArgumentMatchers.anyList())) .thenReturn(new ReaderFailoverResult(mockReaderAConnection, readerA, true)); diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPluginTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPluginTest.java index 030c56ee3..f1075d9fd 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPluginTest.java +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/failover/FailoverConnectionPluginTest.java @@ -179,6 +179,8 @@ void test_updateTopology() throws SQLException { @ValueSource(booleans = {true, false}) void test_updateTopology_withForceUpdate(final boolean forceUpdate) throws SQLException { + when(mockPluginService.getAllHosts()).thenReturn(Collections.singletonList( + new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("host").build())); when(mockPluginService.getHosts()).thenReturn(Collections.singletonList( new HostSpecBuilder(new SimpleHostAvailabilityStrategy()).host("host").build())); when(mockConnection.isClosed()).thenReturn(false); @@ -229,6 +231,7 @@ void test_failoverReader_withValidFailedHostSpec_successFailover() throws SQLExc when(mockHostSpec.getAliases()).thenReturn(new HashSet<>(Arrays.asList("alias1", "alias2"))); when(mockHostSpec.getRawAvailability()).thenReturn(HostAvailability.AVAILABLE); + when(mockPluginService.getAllHosts()).thenReturn(hosts); when(mockPluginService.getHosts()).thenReturn(hosts); when(mockReaderResult.isConnected()).thenReturn(true); when(mockReaderResult.getConnection()).thenReturn(mockConnection); @@ -259,6 +262,7 @@ void test_failoverReader_withVNoFailedHostSpec_withException() throws SQLExcepti when(mockHostSpec.getAliases()).thenReturn(new HashSet<>(Arrays.asList("alias1", "alias2"))); when(mockHostSpec.getAvailability()).thenReturn(HostAvailability.AVAILABLE); + when(mockPluginService.getAllHosts()).thenReturn(hosts); when(mockPluginService.getHosts()).thenReturn(hosts); when(mockReaderResult.getException()).thenReturn(new SQLException()); when(mockReaderResult.getHost()).thenReturn(hostSpec); @@ -282,6 +286,7 @@ void test_failoverWriter_failedFailover_throwsException() throws SQLException { final List hosts = Collections.singletonList(hostSpec); when(mockHostSpec.getAliases()).thenReturn(new HashSet<>(Arrays.asList("alias1", "alias2"))); + when(mockPluginService.getAllHosts()).thenReturn(hosts); when(mockPluginService.getHosts()).thenReturn(hosts); when(mockWriterResult.getException()).thenReturn(new SQLException()); @@ -304,6 +309,7 @@ void test_failoverWriter_failedFailover_withNoResult() throws SQLException { final List hosts = Collections.singletonList(hostSpec); when(mockHostSpec.getAliases()).thenReturn(new HashSet<>(Arrays.asList("alias1", "alias2"))); + when(mockPluginService.getAllHosts()).thenReturn(hosts); when(mockPluginService.getHosts()).thenReturn(hosts); when(mockWriterResult.isConnected()).thenReturn(false); @@ -330,6 +336,7 @@ void test_failoverWriter_successFailover() throws SQLException { final List hosts = Collections.singletonList(hostSpec); when(mockHostSpec.getAliases()).thenReturn(new HashSet<>(Arrays.asList("alias1", "alias2"))); + when(mockPluginService.getAllHosts()).thenReturn(hosts); when(mockPluginService.getHosts()).thenReturn(hosts); initializePlugin(); diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/readwritesplitting/ReadWriteSplittingPluginTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/readwritesplitting/ReadWriteSplittingPluginTest.java index 40aa20c38..438fa73b2 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/plugin/readwritesplitting/ReadWriteSplittingPluginTest.java +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/readwritesplitting/ReadWriteSplittingPluginTest.java @@ -123,6 +123,7 @@ void cleanUp() throws Exception { void mockDefaultBehavior() throws SQLException { when(this.mockPluginService.getCurrentConnection()).thenReturn(mockWriterConn); when(this.mockPluginService.getCurrentHostSpec()).thenReturn(writerHostSpec); + when(this.mockPluginService.getAllHosts()).thenReturn(defaultHosts); when(this.mockPluginService.getHosts()).thenReturn(defaultHosts); when(this.mockPluginService.getHostSpecByStrategy(eq(HostRole.READER), eq("random"))) .thenReturn(readerHostSpec1); @@ -150,7 +151,7 @@ void mockDefaultBehavior() throws SQLException { @Test public void testSetReadOnly_trueFalse() throws SQLException { - when(this.mockPluginService.getHosts()).thenReturn(singleReaderTopology); + when(this.mockPluginService.getAllHosts()).thenReturn(singleReaderTopology); when(mockPluginService.getCurrentConnection()).thenReturn(mockWriterConn); final ReadWriteSplittingPlugin plugin = new ReadWriteSplittingPlugin( @@ -183,7 +184,7 @@ public void testSetReadOnly_trueFalse() throws SQLException { @Test public void testSetReadOnlyTrue_alreadyOnReader() throws SQLException { - when(this.mockPluginService.getHosts()).thenReturn(singleReaderTopology); + when(this.mockPluginService.getAllHosts()).thenReturn(singleReaderTopology); when(mockPluginService.getCurrentConnection()).thenReturn(mockReaderConn1); when(mockPluginService.getCurrentHostSpec()).thenReturn(readerHostSpec1); @@ -204,7 +205,7 @@ public void testSetReadOnlyTrue_alreadyOnReader() throws SQLException { @Test public void testSetReadOnlyFalse_alreadyOnWriter() throws SQLException { - when(this.mockPluginService.getHosts()).thenReturn(singleReaderTopology); + when(this.mockPluginService.getAllHosts()).thenReturn(singleReaderTopology); when(mockPluginService.getCurrentConnection()).thenReturn(mockWriterConn); when(mockPluginService.getCurrentHostSpec()).thenReturn(writerHostSpec); @@ -226,7 +227,7 @@ public void testSetReadOnlyFalse_alreadyOnWriter() throws SQLException { public void testSetReadOnly_falseInTransaction() { when(this.mockPluginService.getCurrentConnection()).thenReturn(mockReaderConn1); when(this.mockPluginService.getCurrentHostSpec()).thenReturn(readerHostSpec1); - when(this.mockPluginService.getHosts()).thenReturn(singleReaderTopology); + when(this.mockPluginService.getAllHosts()).thenReturn(singleReaderTopology); when(mockPluginService.isInTransaction()).thenReturn(true); final ReadWriteSplittingPlugin plugin = new ReadWriteSplittingPlugin( @@ -288,7 +289,7 @@ public void testSetReadOnly_true_oneHost() throws SQLException { public void testSetReadOnly_false_writerConnectionFails() throws SQLException { when(mockPluginService.connect(eq(writerHostSpec), eq(defaultProps))) .thenThrow(SQLException.class); - when(this.mockPluginService.getHosts()).thenReturn(singleReaderTopology); + when(this.mockPluginService.getAllHosts()).thenReturn(singleReaderTopology); when(mockPluginService.getCurrentConnection()).thenReturn(mockReaderConn1); when(mockPluginService.getCurrentHostSpec()).thenReturn(readerHostSpec1); diff --git a/wrapper/src/test/java/software/amazon/jdbc/util/RdsUtilsTests.java b/wrapper/src/test/java/software/amazon/jdbc/util/RdsUtilsTests.java index a5ee6effe..caefb7f88 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/util/RdsUtilsTests.java +++ b/wrapper/src/test/java/software/amazon/jdbc/util/RdsUtilsTests.java @@ -76,7 +76,6 @@ public class RdsUtilsTests { private static final String usEastRegionElbUrl = "elb-name.elb.us-east-2.amazonaws.com"; - private static final String usIsobEastRegionCluster = "database-test-name.cluster-XYZ.rds.us-isob-east-1.sc2s.sgov.gov"; private static final String usIsobEastRegionClusterReadOnly = @@ -479,4 +478,43 @@ public void testRemoveGreenInstancePrefix() { assertEquals("test-instance-green-123456.domain.com", target.removeGreenInstancePrefix("test-instance-green-123456-green-123456.domain.com")); } + + @Test + public void testGetRdsClusterId() { + assertEquals("database-test-name", target.getRdsClusterId(usEastRegionCluster)); + assertEquals("database-test-name", target.getRdsClusterId(usEastRegionClusterReadOnly)); + assertNull(target.getRdsClusterId(usEastRegionInstance)); + assertEquals("proxy-test-name", target.getRdsClusterId(usEastRegionProxy)); + assertEquals("custom-test-name", target.getRdsClusterId(usEastRegionCustomDomain)); + assertEquals("database-test-name", target.getRdsClusterId(usEastRegionLimitlessDbShardGroup)); + + assertEquals("database-test-name", target.getRdsClusterId(chinaRegionCluster)); + assertEquals("database-test-name", target.getRdsClusterId(chinaRegionClusterReadOnly)); + assertNull(target.getRdsClusterId(chinaRegionInstance)); + assertEquals("proxy-test-name", target.getRdsClusterId(chinaRegionProxy)); + assertEquals("custom-test-name", target.getRdsClusterId(chinaRegionCustomDomain)); + assertEquals("database-test-name", target.getRdsClusterId(chinaRegionLimitlessDbShardGroup)); + + assertEquals("database-test-name", target.getRdsClusterId(oldChinaRegionCluster)); + assertEquals("database-test-name", target.getRdsClusterId(oldChinaRegionClusterReadOnly)); + assertNull(target.getRdsClusterId(oldChinaRegionInstance)); + assertEquals("proxy-test-name", target.getRdsClusterId(oldChinaRegionProxy)); + assertEquals("custom-test-name", target.getRdsClusterId(oldChinaRegionCustomDomain)); + assertEquals("database-test-name", target.getRdsClusterId(oldChinaRegionLimitlessDbShardGroup)); + + assertEquals("database-test-name", target.getRdsClusterId(usIsobEastRegionCluster)); + assertEquals("database-test-name", target.getRdsClusterId(usIsobEastRegionClusterReadOnly)); + assertNull(target.getRdsClusterId(usIsobEastRegionInstance)); + assertEquals("proxy-test-name", target.getRdsClusterId(usIsobEastRegionProxy)); + assertEquals("custom-test-name", target.getRdsClusterId(usIsobEastRegionCustomDomain)); + assertEquals("database-test-name", target.getRdsClusterId(usIsobEastRegionLimitlessDbShardGroup)); + + assertEquals("database-test-name", target.getRdsClusterId(usGovEastRegionCluster)); + assertEquals("database-test-name", target.getRdsClusterId(usIsoEastRegionCluster)); + assertEquals("database-test-name", target.getRdsClusterId(usIsoEastRegionClusterReadOnly)); + assertNull(target.getRdsClusterId(usIsoEastRegionInstance)); + assertEquals("proxy-test-name", target.getRdsClusterId(usIsoEastRegionProxy)); + assertEquals("custom-test-name", target.getRdsClusterId(usIsoEastRegionCustomDomain)); + assertEquals("database-test-name", target.getRdsClusterId(usIsoEastRegionLimitlessDbShardGroup)); + } }