Skip to content

Commit e07c358

Browse files
add plugin for DSQL auth tokens
1 parent 3b32ac2 commit e07c358

File tree

8 files changed

+464
-1
lines changed

8 files changed

+464
-1
lines changed

wrapper/build.gradle.kts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ dependencies {
3737
compileOnly("software.amazon.awssdk:rds:2.31.78")
3838
compileOnly("software.amazon.awssdk:auth:2.31.45") // Required for IAM (light implementation)
3939
compileOnly("software.amazon.awssdk:http-client-spi:2.31.60") // Required for IAM (light implementation)
40+
compileOnly("software.amazon.awssdk:dsql:2.31.78")
4041
compileOnly("software.amazon.awssdk:sts:2.31.78")
4142
compileOnly("com.zaxxer:HikariCP:4.0.3") // Version 4.+ is compatible with Java 8
4243
compileOnly("com.mchange:c3p0:0.11.0")
@@ -73,6 +74,7 @@ dependencies {
7374
testImplementation("software.amazon.awssdk:rds:2.31.78")
7475
testImplementation("software.amazon.awssdk:auth:2.31.45") // Required for IAM (light implementation)
7576
testImplementation("software.amazon.awssdk:http-client-spi:2.31.60") // Required for IAM (light implementation)
77+
testImplementation("software.amazon.awssdk:dsql:2.31.78")
7678
testImplementation("software.amazon.awssdk:ec2:2.31.78")
7779
testImplementation("software.amazon.awssdk:secretsmanager:2.31.12")
7880
testImplementation("software.amazon.awssdk:sts:2.31.78")

wrapper/src/main/java/software/amazon/jdbc/ConnectionPluginChainBuilder.java

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
import software.amazon.jdbc.plugin.failover.FailoverConnectionPluginFactory;
4646
import software.amazon.jdbc.plugin.federatedauth.FederatedAuthPluginFactory;
4747
import software.amazon.jdbc.plugin.federatedauth.OktaAuthPluginFactory;
48+
import software.amazon.jdbc.plugin.iam.DsqlIamConnectionPluginFactory;
4849
import software.amazon.jdbc.plugin.iam.IamAuthConnectionPluginFactory;
4950
import software.amazon.jdbc.plugin.limitless.LimitlessConnectionPluginFactory;
5051
import software.amazon.jdbc.plugin.readwritesplitting.ReadWriteSplittingPluginFactory;
@@ -75,6 +76,7 @@ public class ConnectionPluginChainBuilder {
7576
put("failover", new FailoverConnectionPluginFactory());
7677
put("failover2", new software.amazon.jdbc.plugin.failover2.FailoverConnectionPluginFactory());
7778
put("iam", new IamAuthConnectionPluginFactory());
79+
put("dsql", new DsqlIamConnectionPluginFactory());
7880
put("awsSecretsManager", new AwsSecretsManagerConnectionPluginFactory());
7981
put("federatedAuth", new FederatedAuthPluginFactory());
8082
put("okta", new OktaAuthPluginFactory());
@@ -114,6 +116,7 @@ public class ConnectionPluginChainBuilder {
114116
put(FastestResponseStrategyPluginFactory.class, 900);
115117
put(LimitlessConnectionPluginFactory.class, 950);
116118
put(IamAuthConnectionPluginFactory.class, 1000);
119+
put(DsqlIamConnectionPluginFactory.class, 1001);
117120
put(AwsSecretsManagerConnectionPluginFactory.class, 1100);
118121
put(FederatedAuthPluginFactory.class, 1200);
119122
put(LogQueryConnectionPluginFactory.class, 1300);
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
/*
2+
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License").
5+
* You may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package software.amazon.jdbc.plugin.iam;
18+
19+
import java.util.Properties;
20+
import org.checkerframework.checker.nullness.qual.NonNull;
21+
import software.amazon.jdbc.ConnectionPlugin;
22+
import software.amazon.jdbc.ConnectionPluginFactory;
23+
import software.amazon.jdbc.PluginService;
24+
25+
/**
26+
* Provides {@link ConnectionPlugin} instances which can be used to connect to Amazon Aurora DSQL.
27+
*/
28+
public class DsqlIamConnectionPluginFactory implements ConnectionPluginFactory {
29+
@Override
30+
public ConnectionPlugin getInstance(@NonNull final PluginService pluginService, @NonNull final Properties props) {
31+
return new IamAuthConnectionPlugin(pluginService, new DsqlTokenUtility());
32+
}
33+
}
Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
/*
2+
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License").
5+
* You may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package software.amazon.jdbc.plugin.iam;
18+
19+
import org.checkerframework.checker.nullness.qual.NonNull;
20+
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
21+
import software.amazon.awssdk.regions.Region;
22+
import software.amazon.awssdk.services.dsql.DsqlUtilities;
23+
24+
/**
25+
* Represents an {@link IamTokenUtility} which provides auth tokens for connecting to Amazon Aurora DSQL.
26+
*/
27+
public class DsqlTokenUtility implements IamTokenUtility {
28+
29+
private DsqlUtilities utilities = null;
30+
31+
public DsqlTokenUtility() { }
32+
33+
// For testing only
34+
DsqlTokenUtility(@NonNull final DsqlUtilities utilities) {
35+
this.utilities = utilities;
36+
}
37+
38+
@Override
39+
public String generateAuthenticationToken(
40+
@NonNull final AwsCredentialsProvider credentialsProvider,
41+
@NonNull final Region region,
42+
@NonNull final String hostname,
43+
final int port,
44+
@NonNull final String username) {
45+
if (this.utilities == null) {
46+
this.utilities = DsqlUtilities.builder()
47+
.credentialsProvider(credentialsProvider)
48+
.region(region)
49+
.build();
50+
}
51+
if (username.equals("admin")) {
52+
return this.utilities.generateDbConnectAdminAuthToken((builder) ->
53+
builder.hostname(hostname).region(region)
54+
);
55+
} else {
56+
return this.utilities.generateDbConnectAuthToken((builder) ->
57+
builder.hostname(hostname).region(region)
58+
);
59+
}
60+
}
61+
}

wrapper/src/main/java/software/amazon/jdbc/util/RdsUtils.java

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,6 +144,14 @@ public class RdsUtils {
144144
+ "\\.(amazonaws\\.com\\.?|c2s\\.ic\\.gov\\.?|sc2s\\.sgov\\.gov\\.?))$",
145145
Pattern.CASE_INSENSITIVE);
146146

147+
private static final Pattern AURORA_DSQL_CLUSTER_PATTERN =
148+
Pattern.compile(
149+
"^(?<instance>[^.]+)\\."
150+
+ "(?<dns>dsql(?:-[^.]+)?)\\."
151+
+ "(?<domain>(?<region>[a-zA-Z0-9\\-]+)"
152+
+ "\\.on\\.aws\\.?)$",
153+
Pattern.CASE_INSENSITIVE);
154+
147155
private static final Pattern ELB_PATTERN =
148156
Pattern.compile(
149157
"^(?<instance>.+)\\.elb\\."
@@ -259,14 +267,25 @@ public String getRdsInstanceHostPattern(final String host) {
259267
return group == null ? "?" : "?." + group;
260268
}
261269

270+
public String getDsqlInstanceId(final String host) {
271+
final String preparedHost = getPreparedHost(host);
272+
if (StringUtils.isNullOrEmpty(preparedHost)) {
273+
return null;
274+
}
275+
276+
final Matcher matcher = cacheMatcher(preparedHost, AURORA_DSQL_CLUSTER_PATTERN);
277+
return getRegexGroup(matcher, INSTANCE_GROUP);
278+
}
279+
262280
public String getRdsRegion(final String host) {
263281
final String preparedHost = getPreparedHost(host);
264282
if (StringUtils.isNullOrEmpty(preparedHost)) {
265283
return null;
266284
}
267285

268286
final Matcher matcher = cacheMatcher(preparedHost,
269-
AURORA_DNS_PATTERN, AURORA_CHINA_DNS_PATTERN, AURORA_OLD_CHINA_DNS_PATTERN, AURORA_GOV_DNS_PATTERN);
287+
AURORA_DNS_PATTERN, AURORA_CHINA_DNS_PATTERN, AURORA_OLD_CHINA_DNS_PATTERN, AURORA_GOV_DNS_PATTERN,
288+
AURORA_DSQL_CLUSTER_PATTERN);
270289
final String group = getRegexGroup(matcher, REGION_GROUP);
271290
if (group != null) {
272291
return group;
@@ -294,6 +313,11 @@ public boolean isLimitlessDbShardGroupDns(final String host) {
294313
return dnsGroup != null && dnsGroup.equalsIgnoreCase("shardgrp-");
295314
}
296315

316+
public boolean isDsqlCluster(final String host) {
317+
final String instanceId = getDsqlInstanceId(host);
318+
return instanceId != null;
319+
}
320+
297321
public String getRdsClusterHostUrl(final String host) {
298322
final String preparedHost = getPreparedHost(host);
299323
if (StringUtils.isNullOrEmpty(preparedHost)) {
Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
1+
/*
2+
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License").
5+
* You may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package software.amazon.jdbc.plugin.iam;
18+
19+
import static org.junit.jupiter.api.Assertions.assertEquals;
20+
import static org.junit.jupiter.api.Assertions.assertInstanceOf;
21+
import static org.mockito.ArgumentMatchers.anyString;
22+
import static org.mockito.ArgumentMatchers.eq;
23+
import static org.mockito.Mockito.doReturn;
24+
import static software.amazon.jdbc.plugin.iam.DsqlTokenUtilityTest.ADMIN_USER;
25+
import static software.amazon.jdbc.plugin.iam.DsqlTokenUtilityTest.REGULAR_USER;
26+
import static software.amazon.jdbc.plugin.iam.DsqlTokenUtilityTest.assertTokenContainsProperties;
27+
28+
import java.sql.Connection;
29+
import java.sql.SQLException;
30+
import java.util.List;
31+
import java.util.Properties;
32+
import org.junit.jupiter.api.AfterEach;
33+
import org.junit.jupiter.api.BeforeEach;
34+
import org.junit.jupiter.api.Test;
35+
import org.junit.jupiter.params.ParameterizedTest;
36+
import org.junit.jupiter.params.provider.ValueSource;
37+
import org.mockito.Mock;
38+
import org.mockito.Mockito;
39+
import org.mockito.MockitoAnnotations;
40+
import software.amazon.awssdk.regions.Region;
41+
import software.amazon.jdbc.ConnectionPlugin;
42+
import software.amazon.jdbc.ConnectionPluginChainBuilder;
43+
import software.amazon.jdbc.ConnectionProvider;
44+
import software.amazon.jdbc.HostSpec;
45+
import software.amazon.jdbc.HostSpecBuilder;
46+
import software.amazon.jdbc.JdbcCallable;
47+
import software.amazon.jdbc.PluginManagerService;
48+
import software.amazon.jdbc.PluginService;
49+
import software.amazon.jdbc.PropertyDefinition;
50+
import software.amazon.jdbc.dialect.Dialect;
51+
import software.amazon.jdbc.hostavailability.SimpleHostAvailabilityStrategy;
52+
import software.amazon.jdbc.plugin.TokenInfo;
53+
import software.amazon.jdbc.util.FullServicesContainer;
54+
import software.amazon.jdbc.util.IamAuthUtils;
55+
import software.amazon.jdbc.util.telemetry.TelemetryContext;
56+
import software.amazon.jdbc.util.telemetry.TelemetryFactory;
57+
import software.amazon.jdbc.util.telemetry.TelemetryTraceLevel;
58+
59+
class DsqlIamConnectionPluginTest {
60+
61+
private static final Region TEST_REGION = Region.US_EAST_1;
62+
private static final String TEST_HOSTNAME = String.format("foo0bar1baz2quux3quuux4.dsql.%s.on.aws", TEST_REGION);
63+
private static final int TEST_PORT = 5432;
64+
65+
private static final String DRIVER_PROTOCOL = "jdbc:postgresql:";
66+
private static final HostSpec HOST_SPEC = new HostSpecBuilder(new SimpleHostAvailabilityStrategy())
67+
.host(TEST_HOSTNAME).port(TEST_PORT).build();
68+
69+
private static final String DEFAULT_USERNAME = "admin";
70+
71+
private final Properties props = new Properties();
72+
73+
private AutoCloseable cleanMocksCallback;
74+
@Mock private Connection mockConnection;
75+
@Mock private PluginService mockPluginService;
76+
@Mock private FullServicesContainer mockServicesContainer;
77+
@Mock private Dialect mockDialect;
78+
@Mock private TelemetryFactory mockTelemetryFactory;
79+
@Mock private TelemetryContext mockTelemetryContext;
80+
@Mock private JdbcCallable<Connection, SQLException> mockLambda;
81+
@Mock private ConnectionProvider mockConnectionProvider;
82+
@Mock private PluginManagerService mockPluginManagerService;
83+
84+
@BeforeEach
85+
public void init() {
86+
cleanMocksCallback = MockitoAnnotations.openMocks(this);
87+
88+
IamAuthConnectionPlugin.clearCache();
89+
90+
props.setProperty(PropertyDefinition.USER.name, DEFAULT_USERNAME);
91+
props.setProperty("iamRegion", Region.US_EAST_1.toString());
92+
props.setProperty(PropertyDefinition.PLUGINS.name, "dsql");
93+
94+
doReturn(mockPluginService).when(mockServicesContainer).getPluginService();
95+
doReturn(mockDialect).when(mockPluginService).getDialect();
96+
doReturn(TEST_PORT).when(mockDialect).getDefaultPort();
97+
doReturn(mockTelemetryFactory).when(mockPluginService).getTelemetryFactory();
98+
doReturn(mockTelemetryContext).when(mockTelemetryFactory)
99+
.openTelemetryContext(anyString(), eq(TelemetryTraceLevel.NESTED));
100+
}
101+
102+
@AfterEach
103+
public void cleanup() throws Exception {
104+
cleanMocksCallback.close();
105+
}
106+
107+
@SuppressWarnings("resource") // Prevent Mockito warning when mocking closeable return type.
108+
private void assertPluginProvidesDsqlTokens(final ConnectionPlugin plugin, final String username) throws SQLException {
109+
Mockito.doReturn(mockConnection).when(mockLambda).call();
110+
111+
plugin
112+
.connect(DRIVER_PROTOCOL, HOST_SPEC, props, true, mockLambda)
113+
.close();
114+
115+
final String cacheKey = IamAuthUtils.getCacheKey(
116+
username,
117+
TEST_HOSTNAME,
118+
TEST_PORT,
119+
TEST_REGION);
120+
121+
final TokenInfo info = IamAuthCacheHolder.tokenCache.get(cacheKey);
122+
final String token = info.getToken();
123+
124+
assertTokenContainsProperties(token, TEST_HOSTNAME, username);
125+
}
126+
127+
@Test
128+
public void testDsqlPluginRegistration() throws SQLException {
129+
ConnectionPluginChainBuilder builder = new ConnectionPluginChainBuilder();
130+
131+
final List<ConnectionPlugin> result = builder.getPlugins(
132+
mockServicesContainer,
133+
mockConnectionProvider,
134+
null,
135+
mockPluginManagerService,
136+
props,
137+
null);
138+
139+
// 2 because default plugin is always included.
140+
assertEquals(2, result.size());
141+
final ConnectionPlugin plugin = result.get(0);
142+
143+
assertInstanceOf(IamAuthConnectionPlugin.class, plugin);
144+
assertPluginProvidesDsqlTokens(plugin, DEFAULT_USERNAME);
145+
}
146+
147+
@ParameterizedTest
148+
@ValueSource(strings = {REGULAR_USER, ADMIN_USER})
149+
public void testDsqlTokenGeneratedBasedOnUser(final String username) throws SQLException {
150+
props.setProperty(PropertyDefinition.USER.name, username);
151+
152+
final DsqlIamConnectionPluginFactory factory = new DsqlIamConnectionPluginFactory();
153+
final ConnectionPlugin plugin = factory.getInstance(mockPluginService, props);
154+
assertPluginProvidesDsqlTokens(plugin, username);
155+
}
156+
}

0 commit comments

Comments
 (0)