From b879640dfac11f978d32d0e12cb3ee3f8a8f6ca4 Mon Sep 17 00:00:00 2001 From: Daniel Frankcom Date: Wed, 16 Jul 2025 11:01:25 -0700 Subject: [PATCH] add plugin for DSQL auth tokens --- wrapper/build.gradle.kts | 2 + .../jdbc/ConnectionPluginChainBuilder.java | 3 + .../federatedauth/FederatedAuthPlugin.java | 2 +- .../plugin/federatedauth/OktaAuthPlugin.java | 2 +- .../iam/DsqlIamConnectionPluginFactory.java | 34 +++ .../jdbc/plugin/iam/DsqlTokenUtility.java | 53 +++++ .../plugin/iam/IamAuthConnectionPlugin.java | 6 +- .../iam/IamAuthConnectionPluginFactory.java | 3 +- .../amazon/jdbc/util/IamAuthUtils.java | 12 +- .../software/amazon/jdbc/util/RdsUtils.java | 26 ++- ..._advanced_jdbc_wrapper_messages.properties | 1 + .../iam/DsqlIamConnectionPluginTest.java | 205 ++++++++++++++++++ .../jdbc/plugin/iam/DsqlTokenUtilityTest.java | 134 ++++++++++++ .../amazon/jdbc/util/RdsUtilsTests.java | 35 +++ 14 files changed, 508 insertions(+), 10 deletions(-) create mode 100644 wrapper/src/main/java/software/amazon/jdbc/plugin/iam/DsqlIamConnectionPluginFactory.java create mode 100644 wrapper/src/main/java/software/amazon/jdbc/plugin/iam/DsqlTokenUtility.java create mode 100644 wrapper/src/test/java/software/amazon/jdbc/plugin/iam/DsqlIamConnectionPluginTest.java create mode 100644 wrapper/src/test/java/software/amazon/jdbc/plugin/iam/DsqlTokenUtilityTest.java diff --git a/wrapper/build.gradle.kts b/wrapper/build.gradle.kts index f1f544545..469cb5ce9 100644 --- a/wrapper/build.gradle.kts +++ b/wrapper/build.gradle.kts @@ -37,6 +37,7 @@ dependencies { compileOnly("software.amazon.awssdk:rds:2.31.78") compileOnly("software.amazon.awssdk:auth:2.31.45") // Required for IAM (light implementation) compileOnly("software.amazon.awssdk:http-client-spi:2.31.60") // Required for IAM (light implementation) + compileOnly("software.amazon.awssdk:dsql:2.31.78") compileOnly("software.amazon.awssdk:sts:2.31.78") compileOnly("com.zaxxer:HikariCP:4.0.3") // Version 4.+ is compatible with Java 8 compileOnly("com.mchange:c3p0:0.11.0") @@ -73,6 +74,7 @@ dependencies { testImplementation("software.amazon.awssdk:rds:2.31.78") testImplementation("software.amazon.awssdk:auth:2.31.45") // Required for IAM (light implementation) testImplementation("software.amazon.awssdk:http-client-spi:2.31.60") // Required for IAM (light implementation) + testImplementation("software.amazon.awssdk:dsql:2.31.78") testImplementation("software.amazon.awssdk:ec2:2.31.78") testImplementation("software.amazon.awssdk:secretsmanager:2.31.12") testImplementation("software.amazon.awssdk:sts:2.31.78") diff --git a/wrapper/src/main/java/software/amazon/jdbc/ConnectionPluginChainBuilder.java b/wrapper/src/main/java/software/amazon/jdbc/ConnectionPluginChainBuilder.java index 411e40cd8..3a14cc7ba 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/ConnectionPluginChainBuilder.java +++ b/wrapper/src/main/java/software/amazon/jdbc/ConnectionPluginChainBuilder.java @@ -45,6 +45,7 @@ import software.amazon.jdbc.plugin.failover.FailoverConnectionPluginFactory; import software.amazon.jdbc.plugin.federatedauth.FederatedAuthPluginFactory; import software.amazon.jdbc.plugin.federatedauth.OktaAuthPluginFactory; +import software.amazon.jdbc.plugin.iam.DsqlIamConnectionPluginFactory; import software.amazon.jdbc.plugin.iam.IamAuthConnectionPluginFactory; import software.amazon.jdbc.plugin.limitless.LimitlessConnectionPluginFactory; import software.amazon.jdbc.plugin.readwritesplitting.ReadWriteSplittingPluginFactory; @@ -75,6 +76,7 @@ public class ConnectionPluginChainBuilder { put("failover", new FailoverConnectionPluginFactory()); put("failover2", new software.amazon.jdbc.plugin.failover2.FailoverConnectionPluginFactory()); put("iam", new IamAuthConnectionPluginFactory()); + put("iamDsql", new DsqlIamConnectionPluginFactory()); put("awsSecretsManager", new AwsSecretsManagerConnectionPluginFactory()); put("federatedAuth", new FederatedAuthPluginFactory()); put("okta", new OktaAuthPluginFactory()); @@ -114,6 +116,7 @@ public class ConnectionPluginChainBuilder { put(FastestResponseStrategyPluginFactory.class, 900); put(LimitlessConnectionPluginFactory.class, 950); put(IamAuthConnectionPluginFactory.class, 1000); + put(DsqlIamConnectionPluginFactory.class, 1001); put(AwsSecretsManagerConnectionPluginFactory.class, 1100); put(FederatedAuthPluginFactory.class, 1200); put(LogQueryConnectionPluginFactory.class, 1300); diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/federatedauth/FederatedAuthPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/federatedauth/FederatedAuthPlugin.java index 6a06e1d68..3f00426f3 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/federatedauth/FederatedAuthPlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/federatedauth/FederatedAuthPlugin.java @@ -121,7 +121,7 @@ public Set getSubscribedMethods() { public FederatedAuthPlugin(final PluginService pluginService, final CredentialsProviderFactory credentialsProviderFactory) { - this(pluginService, credentialsProviderFactory, new RdsUtils(), IamAuthUtils.getTokenUtility()); + this(pluginService, credentialsProviderFactory, new RdsUtils(), IamAuthUtils.getRdsTokenUtility()); } FederatedAuthPlugin( diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/federatedauth/OktaAuthPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/federatedauth/OktaAuthPlugin.java index 7e47ff479..88c5592d9 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/federatedauth/OktaAuthPlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/federatedauth/OktaAuthPlugin.java @@ -95,7 +95,7 @@ public class OktaAuthPlugin extends AbstractConnectionPlugin { private final TelemetryCounter fetchTokenCounter; public OktaAuthPlugin(PluginService pluginService, CredentialsProviderFactory credentialsProviderFactory) { - this(pluginService, credentialsProviderFactory, new RdsUtils(), IamAuthUtils.getTokenUtility()); + this(pluginService, credentialsProviderFactory, new RdsUtils(), IamAuthUtils.getRdsTokenUtility()); } OktaAuthPlugin( diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/iam/DsqlIamConnectionPluginFactory.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/iam/DsqlIamConnectionPluginFactory.java new file mode 100644 index 000000000..5b6ed3148 --- /dev/null +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/iam/DsqlIamConnectionPluginFactory.java @@ -0,0 +1,34 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package software.amazon.jdbc.plugin.iam; + +import java.util.Properties; +import org.checkerframework.checker.nullness.qual.NonNull; +import software.amazon.jdbc.ConnectionPlugin; +import software.amazon.jdbc.ConnectionPluginFactory; +import software.amazon.jdbc.PluginService; +import software.amazon.jdbc.util.IamAuthUtils; + +/** + * Provides {@link ConnectionPlugin} instances which can be used to connect to Amazon Aurora DSQL. + */ +public class DsqlIamConnectionPluginFactory implements ConnectionPluginFactory { + @Override + public ConnectionPlugin getInstance(@NonNull final PluginService pluginService, @NonNull final Properties props) { + return new IamAuthConnectionPlugin(pluginService, IamAuthUtils.getDsqlTokenUtility()); + } +} diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/iam/DsqlTokenUtility.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/iam/DsqlTokenUtility.java new file mode 100644 index 000000000..64968605a --- /dev/null +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/iam/DsqlTokenUtility.java @@ -0,0 +1,53 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package software.amazon.jdbc.plugin.iam; + +import org.checkerframework.checker.nullness.qual.NonNull; +import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; +import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.services.dsql.DsqlUtilities; + +/** + * Represents an {@link IamTokenUtility} which provides auth tokens for connecting to Amazon Aurora DSQL. + */ +public class DsqlTokenUtility implements IamTokenUtility { + + public DsqlTokenUtility() { } + + @Override + public String generateAuthenticationToken( + @NonNull final AwsCredentialsProvider credentialsProvider, + @NonNull final Region region, + @NonNull final String hostname, + final int port, + @NonNull final String username) { + final DsqlUtilities utilities = DsqlUtilities.builder() + .credentialsProvider(credentialsProvider) + .region(region) + .build(); + + if (username.equals("admin")) { + return utilities.generateDbConnectAdminAuthToken((builder) -> + builder.hostname(hostname).region(region) + ); + } else { + return utilities.generateDbConnectAuthToken((builder) -> + builder.hostname(hostname).region(region) + ); + } + } +} diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/iam/IamAuthConnectionPlugin.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/iam/IamAuthConnectionPlugin.java index 5541ac917..afbdc535f 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/iam/IamAuthConnectionPlugin.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/iam/IamAuthConnectionPlugin.java @@ -86,11 +86,7 @@ public class IamAuthConnectionPlugin extends AbstractConnectionPlugin { private final IamTokenUtility iamTokenUtility; - public IamAuthConnectionPlugin(final @NonNull PluginService pluginService) { - this(pluginService, IamAuthUtils.getTokenUtility()); - } - - IamAuthConnectionPlugin(final @NonNull PluginService pluginService, IamTokenUtility utility) { + public IamAuthConnectionPlugin(final @NonNull PluginService pluginService, final IamTokenUtility utility) { this.iamTokenUtility = utility; this.pluginService = pluginService; this.telemetryFactory = pluginService.getTelemetryFactory(); diff --git a/wrapper/src/main/java/software/amazon/jdbc/plugin/iam/IamAuthConnectionPluginFactory.java b/wrapper/src/main/java/software/amazon/jdbc/plugin/iam/IamAuthConnectionPluginFactory.java index 7173ca249..262a21b94 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/plugin/iam/IamAuthConnectionPluginFactory.java +++ b/wrapper/src/main/java/software/amazon/jdbc/plugin/iam/IamAuthConnectionPluginFactory.java @@ -20,10 +20,11 @@ import software.amazon.jdbc.ConnectionPlugin; import software.amazon.jdbc.ConnectionPluginFactory; import software.amazon.jdbc.PluginService; +import software.amazon.jdbc.util.IamAuthUtils; public class IamAuthConnectionPluginFactory implements ConnectionPluginFactory { @Override public ConnectionPlugin getInstance(final PluginService pluginService, final Properties props) { - return new IamAuthConnectionPlugin(pluginService); + return new IamAuthConnectionPlugin(pluginService, IamAuthUtils.getRdsTokenUtility()); } } diff --git a/wrapper/src/main/java/software/amazon/jdbc/util/IamAuthUtils.java b/wrapper/src/main/java/software/amazon/jdbc/util/IamAuthUtils.java index 55e5e181a..3aebc9d4a 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/util/IamAuthUtils.java +++ b/wrapper/src/main/java/software/amazon/jdbc/util/IamAuthUtils.java @@ -21,6 +21,7 @@ import software.amazon.awssdk.regions.Region; import software.amazon.jdbc.HostSpec; import software.amazon.jdbc.PluginService; +import software.amazon.jdbc.plugin.iam.DsqlTokenUtility; import software.amazon.jdbc.plugin.iam.IamTokenUtility; import software.amazon.jdbc.plugin.iam.LightRdsUtility; import software.amazon.jdbc.plugin.iam.RegularRdsUtility; @@ -59,7 +60,7 @@ public static String getCacheKey( return String.format("%s:%s:%d:%s", region, hostname, port, user); } - public static IamTokenUtility getTokenUtility() { + public static IamTokenUtility getRdsTokenUtility() { try { // RegularRdsUtility requires AWS Java SDK RDS v2.x to be presented in classpath. Class.forName("software.amazon.awssdk.services.rds.RdsUtilities"); @@ -81,6 +82,15 @@ public static IamTokenUtility getTokenUtility() { } } + public static IamTokenUtility getDsqlTokenUtility() { + try { + Class.forName("software.amazon.awssdk.services.dsql.DsqlUtilities"); + return new DsqlTokenUtility(); + } catch (final ClassNotFoundException e) { + throw new RuntimeException(Messages.get("AuthenticationToken.javaDsqlSdkNotInClasspath"), e); + } + } + public static String generateAuthenticationToken( final IamTokenUtility tokenUtils, final PluginService pluginService, diff --git a/wrapper/src/main/java/software/amazon/jdbc/util/RdsUtils.java b/wrapper/src/main/java/software/amazon/jdbc/util/RdsUtils.java index e9177a277..076e894c5 100644 --- a/wrapper/src/main/java/software/amazon/jdbc/util/RdsUtils.java +++ b/wrapper/src/main/java/software/amazon/jdbc/util/RdsUtils.java @@ -144,6 +144,14 @@ public class RdsUtils { + "\\.(amazonaws\\.com\\.?|c2s\\.ic\\.gov\\.?|sc2s\\.sgov\\.gov\\.?))$", Pattern.CASE_INSENSITIVE); + private static final Pattern AURORA_DSQL_CLUSTER_PATTERN = + Pattern.compile( + "^(?[^.]+)\\." + + "(?dsql(?:-[^.]+)?)\\." + + "(?(?[a-zA-Z0-9\\-]+)" + + "\\.on\\.aws\\.?)$", + Pattern.CASE_INSENSITIVE); + private static final Pattern ELB_PATTERN = Pattern.compile( "^(?.+)\\.elb\\." @@ -259,6 +267,16 @@ public String getRdsInstanceHostPattern(final String host) { return group == null ? "?" : "?." + group; } + public String getDsqlInstanceId(final String host) { + final String preparedHost = getPreparedHost(host); + if (StringUtils.isNullOrEmpty(preparedHost)) { + return null; + } + + final Matcher matcher = cacheMatcher(preparedHost, AURORA_DSQL_CLUSTER_PATTERN); + return getRegexGroup(matcher, INSTANCE_GROUP); + } + public String getRdsRegion(final String host) { final String preparedHost = getPreparedHost(host); if (StringUtils.isNullOrEmpty(preparedHost)) { @@ -266,7 +284,8 @@ public String getRdsRegion(final String host) { } final Matcher matcher = cacheMatcher(preparedHost, - AURORA_DNS_PATTERN, AURORA_CHINA_DNS_PATTERN, AURORA_OLD_CHINA_DNS_PATTERN, AURORA_GOV_DNS_PATTERN); + AURORA_DNS_PATTERN, AURORA_CHINA_DNS_PATTERN, AURORA_OLD_CHINA_DNS_PATTERN, AURORA_GOV_DNS_PATTERN, + AURORA_DSQL_CLUSTER_PATTERN); final String group = getRegexGroup(matcher, REGION_GROUP); if (group != null) { return group; @@ -294,6 +313,11 @@ public boolean isLimitlessDbShardGroupDns(final String host) { return dnsGroup != null && dnsGroup.equalsIgnoreCase("shardgrp-"); } + public boolean isDsqlCluster(final String host) { + final String instanceId = getDsqlInstanceId(host); + return instanceId != null; + } + public String getRdsClusterHostUrl(final String host) { final String preparedHost = getPreparedHost(host); if (StringUtils.isNullOrEmpty(preparedHost)) { diff --git a/wrapper/src/main/resources/aws_advanced_jdbc_wrapper_messages.properties b/wrapper/src/main/resources/aws_advanced_jdbc_wrapper_messages.properties index f60196f1b..2d6c00c3b 100644 --- a/wrapper/src/main/resources/aws_advanced_jdbc_wrapper_messages.properties +++ b/wrapper/src/main/resources/aws_advanced_jdbc_wrapper_messages.properties @@ -47,6 +47,7 @@ AwsSdk.unsupportedRegion=Unsupported AWS region ''{0}''. For supported regions p AwsSecretsManagerConnectionPlugin.endpointOverrideMisconfigured=The provided endpoint is invalid and could not be used to create a URI: `{0}`. AwsSecretsManagerConnectionPlugin.endpointOverrideInvalidConnection=A connection to the provided endpoint could not be established: `{0}`. AwsSecretsManagerConnectionPlugin.javaSdkNotInClasspath=Required dependency 'AWS Java SDK for AWS Secrets Manager' is not on the classpath. +AuthenticationToken.javaDsqlSdkNotInClasspath=Required dependency 'AWS Java SDK for DSQL v2.x' is not on the classpath. AwsSecretsManagerConnectionPlugin.jacksonDatabindNotInClasspath=Required dependency 'Jackson Databind' is not on the classpath. AwsSecretsManagerConnectionPlugin.failedToFetchDbCredentials=Was not able to either fetch or read the database credentials from AWS Secrets Manager. Ensure the correct secretId and region properties have been provided. AwsSecretsManagerConnectionPlugin.missingRequiredConfigParameter=Configuration parameter ''{0}'' is required. diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/iam/DsqlIamConnectionPluginTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/iam/DsqlIamConnectionPluginTest.java new file mode 100644 index 000000000..794232c7a --- /dev/null +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/iam/DsqlIamConnectionPluginTest.java @@ -0,0 +1,205 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package software.amazon.jdbc.plugin.iam; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.mockStatic; +import static software.amazon.jdbc.plugin.iam.DsqlTokenUtilityTest.ADMIN_USER; +import static software.amazon.jdbc.plugin.iam.DsqlTokenUtilityTest.REGULAR_USER; + +import java.sql.Connection; +import java.sql.SQLException; +import java.util.List; +import java.util.Properties; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; +import org.mockito.Mock; +import org.mockito.MockedStatic; +import org.mockito.Mockito; +import org.mockito.MockitoAnnotations; +import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; +import software.amazon.awssdk.identity.spi.AwsCredentialsIdentity; +import software.amazon.awssdk.regions.Region; +import software.amazon.jdbc.ConnectionPlugin; +import software.amazon.jdbc.ConnectionPluginChainBuilder; +import software.amazon.jdbc.ConnectionProvider; +import software.amazon.jdbc.HostSpec; +import software.amazon.jdbc.HostSpecBuilder; +import software.amazon.jdbc.JdbcCallable; +import software.amazon.jdbc.PluginManagerService; +import software.amazon.jdbc.PluginService; +import software.amazon.jdbc.PropertyDefinition; +import software.amazon.jdbc.authentication.AwsCredentialsManager; +import software.amazon.jdbc.dialect.Dialect; +import software.amazon.jdbc.hostavailability.SimpleHostAvailabilityStrategy; +import software.amazon.jdbc.plugin.TokenInfo; +import software.amazon.jdbc.util.FullServicesContainer; +import software.amazon.jdbc.util.IamAuthUtils; +import software.amazon.jdbc.util.telemetry.TelemetryContext; +import software.amazon.jdbc.util.telemetry.TelemetryFactory; +import software.amazon.jdbc.util.telemetry.TelemetryTraceLevel; + +class DsqlIamConnectionPluginTest { + + private static final Region TEST_REGION = Region.US_EAST_1; + private static final String TEST_HOSTNAME = String.format("foo0bar1baz2quux3quuux4.dsql.%s.on.aws", TEST_REGION); + private static final int TEST_PORT = 5432; + + private static final String DRIVER_PROTOCOL = "jdbc:postgresql:"; + private static final HostSpec HOST_SPEC = new HostSpecBuilder(new SimpleHostAvailabilityStrategy()) + .host(TEST_HOSTNAME).port(TEST_PORT).build(); + + private static final String DEFAULT_USERNAME = "admin"; + + private final Properties props = new Properties(); + + static void assertTokenContainsProperties(final String token, final String username) { + assertNotNull(token); + assertFalse(token.isEmpty()); + + assertTrue(token.contains(TEST_HOSTNAME)); + + final String expectedAction; + if (username.equals("admin")) { + expectedAction = "DbConnectAdmin"; + } else { + expectedAction = "DbConnect"; + } + + // Include the ampersand to ensure the complete action is compared. + assertTrue(token.contains("Action=" + expectedAction + "&")); + } + + private AutoCloseable cleanMocksCallback; + @Mock private Connection mockConnection; + @Mock private PluginService mockPluginService; + @Mock private FullServicesContainer mockServicesContainer; + @Mock private Dialect mockDialect; + @Mock private TelemetryFactory mockTelemetryFactory; + @Mock private TelemetryContext mockTelemetryContext; + @Mock private JdbcCallable mockLambda; + @Mock private ConnectionProvider mockConnectionProvider; + @Mock private PluginManagerService mockPluginManagerService; + @Mock private AwsCredentialsProvider mockCredentialsProvider; + @Mock private CompletableFuture completableFuture; + @Mock private AwsCredentialsIdentity mockCredentialsIdentity; + + private MockedStatic mockAwsCredsManagerClass; + + @BeforeEach + public void setup() throws ExecutionException, InterruptedException { + cleanMocksCallback = MockitoAnnotations.openMocks(this); + + IamAuthConnectionPlugin.clearCache(); + + props.setProperty(PropertyDefinition.USER.name, DEFAULT_USERNAME); + props.setProperty("iamRegion", Region.US_EAST_1.toString()); + props.setProperty(PropertyDefinition.PLUGINS.name, "iamDsql"); + + doReturn(mockPluginService).when(mockServicesContainer).getPluginService(); + doReturn(mockDialect).when(mockPluginService).getDialect(); + doReturn(TEST_PORT).when(mockDialect).getDefaultPort(); + doReturn(mockTelemetryFactory).when(mockPluginService).getTelemetryFactory(); + doReturn(mockTelemetryContext).when(mockTelemetryFactory) + .openTelemetryContext(anyString(), eq(TelemetryTraceLevel.NESTED)); + + // Intercept calls to get AWS credentials, and provide mocked equivalents. + mockAwsCredsManagerClass = mockStatic(AwsCredentialsManager.class); + mockAwsCredsManagerClass + .when(() -> AwsCredentialsManager.getProvider(any(), any())) + .thenReturn(mockCredentialsProvider); + + doReturn(completableFuture).when(mockCredentialsProvider).resolveIdentity(); + doReturn(mockCredentialsIdentity).when(completableFuture).get(); + + // These must return non-null values in order for signing to proceed. + doReturn("accessKeyId").when(mockCredentialsIdentity).accessKeyId(); + doReturn("secretAccessKey").when(mockCredentialsIdentity).secretAccessKey(); + } + + @AfterEach + public void tearDown() throws Exception { + cleanMocksCallback.close(); + if (mockAwsCredsManagerClass != null) { + mockAwsCredsManagerClass.close(); + } + } + + @SuppressWarnings("resource") // Prevent Mockito warning when mocking closeable return type. + private void assertPluginProvidesDsqlTokens(final ConnectionPlugin plugin, final String username) + throws SQLException { + Mockito.doReturn(mockConnection).when(mockLambda).call(); + + plugin + .connect(DRIVER_PROTOCOL, HOST_SPEC, props, true, mockLambda) + .close(); + + final String cacheKey = IamAuthUtils.getCacheKey( + username, + TEST_HOSTNAME, + TEST_PORT, + TEST_REGION); + + final TokenInfo info = IamAuthCacheHolder.tokenCache.get(cacheKey); + final String token = info.getToken(); + + assertTokenContainsProperties(token, username); + } + + @Test + public void testDsqlPluginRegistration() throws SQLException { + ConnectionPluginChainBuilder builder = new ConnectionPluginChainBuilder(); + + final List result = builder.getPlugins( + mockServicesContainer, + mockConnectionProvider, + null, + mockPluginManagerService, + props, + null); + + // 2 because default plugin is always included. + assertEquals(2, result.size()); + final ConnectionPlugin plugin = result.get(0); + + assertInstanceOf(IamAuthConnectionPlugin.class, plugin); + assertPluginProvidesDsqlTokens(plugin, DEFAULT_USERNAME); + } + + @ParameterizedTest + @ValueSource(strings = {REGULAR_USER, ADMIN_USER}) + public void testDsqlTokenGeneratedBasedOnUser(final String username) throws SQLException { + props.setProperty(PropertyDefinition.USER.name, username); + + final DsqlIamConnectionPluginFactory factory = new DsqlIamConnectionPluginFactory(); + final ConnectionPlugin plugin = factory.getInstance(mockPluginService, props); + assertPluginProvidesDsqlTokens(plugin, username); + } +} diff --git a/wrapper/src/test/java/software/amazon/jdbc/plugin/iam/DsqlTokenUtilityTest.java b/wrapper/src/test/java/software/amazon/jdbc/plugin/iam/DsqlTokenUtilityTest.java new file mode 100644 index 000000000..d7f4cbcea --- /dev/null +++ b/wrapper/src/test/java/software/amazon/jdbc/plugin/iam/DsqlTokenUtilityTest.java @@ -0,0 +1,134 @@ +/* + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package software.amazon.jdbc.plugin.iam; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.mockStatic; +import static org.mockito.Mockito.verify; + +import java.util.function.Consumer; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; +import org.mockito.ArgumentCaptor; +import org.mockito.ArgumentMatchers; +import org.mockito.Captor; +import org.mockito.Mock; +import org.mockito.MockedStatic; +import org.mockito.MockitoAnnotations; +import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; +import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.services.dsql.DsqlUtilities; +import software.amazon.awssdk.services.dsql.model.GenerateAuthTokenRequest; + +class DsqlTokenUtilityTest { + + static final String REGULAR_USER = "testUser"; + static final String ADMIN_USER = "admin"; + + private static final Region TEST_REGION = Region.US_EAST_1; + private static final String TEST_HOSTNAME = String.format("foo0bar1baz2quux3quuux4.dsql.%s.on.aws", TEST_REGION); + private static final int TEST_PORT = 5432; + + private AutoCloseable cleanMocksCallback; + @Mock private DsqlUtilities mockDsqlUtilities; + @Mock private AwsCredentialsProvider mockCredentialsProvider; + @Mock private DsqlUtilities.Builder mockBuilder; + @Captor private ArgumentCaptor> captor; + + private MockedStatic mockDsqlUtilitiesClass; + + @BeforeEach + public void setup() throws Exception { + cleanMocksCallback = MockitoAnnotations.openMocks(this); + + mockDsqlUtilitiesClass = mockStatic(DsqlUtilities.class); + mockDsqlUtilitiesClass.when(DsqlUtilities::builder).thenReturn(mockBuilder); + + doReturn(mockBuilder).when(mockBuilder).credentialsProvider(any()); + doReturn(mockBuilder).when(mockBuilder).region(any()); + doReturn(mockDsqlUtilities).when(mockBuilder).build(); + } + + @AfterEach + public void tearDown() throws Exception { + cleanMocksCallback.close(); + if (mockDsqlUtilitiesClass != null) { + mockDsqlUtilitiesClass.close(); + } + } + + @ParameterizedTest + @ValueSource(strings = {REGULAR_USER, ADMIN_USER}) + public void testTokenUtilityCallsCorrectAuthMethod(final String username) { + final DsqlTokenUtility tokenUtility = new DsqlTokenUtility(); + + final String expectedToken = "test-token"; + + if (username.equals("admin")) { + doReturn(expectedToken) + .when(mockDsqlUtilities) + .generateDbConnectAdminAuthToken(ArgumentMatchers.>any()); + } else { + doReturn(expectedToken) + .when(mockDsqlUtilities) + .generateDbConnectAuthToken(ArgumentMatchers.>any()); + } + + final String actualToken = tokenUtility.generateAuthenticationToken( + mockCredentialsProvider, TEST_REGION, TEST_HOSTNAME, TEST_PORT, username); + + assertEquals(expectedToken, actualToken); + } + + @ParameterizedTest + @ValueSource(strings = {REGULAR_USER, ADMIN_USER}) + public void testTokenRequestHasProvidedProperties(final String username) { + final DsqlTokenUtility tokenUtility = new DsqlTokenUtility(); + + tokenUtility.generateAuthenticationToken( + mockCredentialsProvider, TEST_REGION, TEST_HOSTNAME, TEST_PORT, username); + + if (username.equals("admin")) { + verify(mockDsqlUtilities).generateDbConnectAdminAuthToken(captor.capture()); + } else { + verify(mockDsqlUtilities).generateDbConnectAuthToken(captor.capture()); + } + + final GenerateAuthTokenRequest.Builder builder = GenerateAuthTokenRequest.builder(); + captor.getValue().accept(builder); + final GenerateAuthTokenRequest request = builder.build(); + + assertEquals(TEST_HOSTNAME, request.hostname()); + assertEquals(TEST_REGION, request.region()); + } + + @ParameterizedTest + @ValueSource(strings = {REGULAR_USER, ADMIN_USER}) + public void testBuilderConfiguredWithProvidedProperties(final String username) { + final DsqlTokenUtility tokenUtility = new DsqlTokenUtility(); + tokenUtility.generateAuthenticationToken( + mockCredentialsProvider, TEST_REGION, TEST_HOSTNAME, TEST_PORT, username); + + verify(mockBuilder).credentialsProvider(mockCredentialsProvider); + verify(mockBuilder).region(TEST_REGION); + verify(mockBuilder).build(); + } +} diff --git a/wrapper/src/test/java/software/amazon/jdbc/util/RdsUtilsTests.java b/wrapper/src/test/java/software/amazon/jdbc/util/RdsUtilsTests.java index 59ee12619..e8423f222 100644 --- a/wrapper/src/test/java/software/amazon/jdbc/util/RdsUtilsTests.java +++ b/wrapper/src/test/java/software/amazon/jdbc/util/RdsUtilsTests.java @@ -116,6 +116,15 @@ public class RdsUtilsTests { private static final String usIsoEastRegionLimitlessDbShardGroup = "database-test-name.shardgrp-XYZ.rds.us-iso-east-1.c2s.ic.gov"; + private static final String usEastRegionDsqlInstance = + "dsql-cluster-identifier-1.dsql.us-east-2.on.aws"; + private static final String usEastRegionDsqlInstanceTrailingDot = + "dsql-cluster-identifier-2.dsql.us-east-2.on.aws."; + private static final String usWestRegionDsqlInstance = + "dsql-cluster-identifier-3.dsql.us-west-1.on.aws"; + private static final String usWestRegionDsqlGammaInstance = + "dsql-cluster-identifier-4.dsql-gamma.us-west-1.on.aws"; + @BeforeEach public void setupTests() { RdsUtils.clearCache(); @@ -166,6 +175,26 @@ public void testIsRdsDns() { assertTrue(target.isRdsDns(usIsoEastRegionLimitlessDbShardGroup)); } + @Test + public void testGetDsqlInstanceId() { + assertEquals("dsql-cluster-identifier-1", target.getDsqlInstanceId(usEastRegionDsqlInstance)); + assertEquals("dsql-cluster-identifier-2", target.getDsqlInstanceId(usEastRegionDsqlInstanceTrailingDot)); + assertEquals("dsql-cluster-identifier-3", target.getDsqlInstanceId(usWestRegionDsqlInstance)); + assertEquals("dsql-cluster-identifier-4", target.getDsqlInstanceId(usWestRegionDsqlGammaInstance)); + } + + @Test + public void testIsDsqlCluster() { + assertTrue(target.isDsqlCluster(usEastRegionDsqlInstance)); + assertTrue(target.isDsqlCluster(usEastRegionDsqlInstanceTrailingDot)); + assertTrue(target.isDsqlCluster(usWestRegionDsqlInstance)); + assertTrue(target.isDsqlCluster(usWestRegionDsqlGammaInstance)); + + assertFalse(target.isDsqlCluster(usIsobEastRegionCluster)); + assertFalse(target.isDsqlCluster(usEastRegionProxy)); + assertFalse(target.isDsqlCluster("https://www.amazon.com")); + } + @Test public void testGetRdsInstanceHostPattern() { final String expectedHostPattern = "?.XYZ.us-east-2.rds.amazonaws.com"; @@ -378,6 +407,12 @@ public void testGetRdsRegion() { assertEquals(expectedHostPattern, target.getRdsRegion(usEastRegionCustomDomain)); assertEquals(expectedHostPattern, target.getRdsRegion(usEastRegionElbUrl)); assertEquals(expectedHostPattern, target.getRdsRegion(usEastRegionLimitlessDbShardGroup)); + assertEquals(expectedHostPattern, target.getRdsRegion(usEastRegionDsqlInstance)); + assertEquals(expectedHostPattern, target.getRdsRegion(usEastRegionDsqlInstanceTrailingDot)); + + final String westExpectedHostPattern = "us-west-1"; + assertEquals(westExpectedHostPattern, target.getRdsRegion(usWestRegionDsqlInstance)); + assertEquals(westExpectedHostPattern, target.getRdsRegion(usWestRegionDsqlGammaInstance)); final String govExpectedHostPattern = "us-gov-east-1"; assertEquals(govExpectedHostPattern, target.getRdsRegion(usGovEastRegionCluster));