From 84837fcb28e5ac8d057c511eb19cfeeff7d299e1 Mon Sep 17 00:00:00 2001 From: cbudau Date: Mon, 3 Mar 2025 11:58:31 +0200 Subject: [PATCH 1/2] feat(credential-providers): cache credentials across clients Avoid redundant credential fetches by caching credentials globally across AWS SDK v3 clients. --- .../src/defaultProvider.ts | 140 ++++++++++-------- 1 file changed, 80 insertions(+), 60 deletions(-) diff --git a/packages/credential-provider-node/src/defaultProvider.ts b/packages/credential-provider-node/src/defaultProvider.ts index b16112434ec6c..a25fe2a5459b9 100644 --- a/packages/credential-provider-node/src/defaultProvider.ts +++ b/packages/credential-provider-node/src/defaultProvider.ts @@ -26,6 +26,11 @@ export type DefaultProviderInit = FromIniInit & */ let multipleCredentialSourceWarningEmitted = false; +/** + * @internal + */ +const credentialCache = new WeakMap<() => Promise, AwsCredentialIdentity>(); + /** * Creates a credential provider that will attempt to find credentials from the * following sources (listed in order of precedence): @@ -60,19 +65,18 @@ let multipleCredentialSourceWarningEmitted = false; * @see {@link fromContainerMetadata} The function used to source credentials from the * ECS Container Metadata Service. */ -export const defaultProvider = (init: DefaultProviderInit = {}): MemoizedProvider => - memoize( - chain( - async () => { - const profile = init.profile ?? process.env[ENV_PROFILE]; - if (profile) { - const envStaticCredentialsAreSet = process.env[ENV_KEY] && process.env[ENV_SECRET]; - if (envStaticCredentialsAreSet) { - if (!multipleCredentialSourceWarningEmitted) { - const warnFn = - init.logger?.warn && init.logger?.constructor?.name !== "NoOpLogger" ? init.logger.warn : console.warn; - warnFn( - `@aws-sdk/credential-provider-node - defaultProvider::fromEnv WARNING: +export const defaultProvider = (init: DefaultProviderInit = {}): MemoizedProvider => { + const providerChain = chain( + async () => { + const profile = init.profile ?? process.env[ENV_PROFILE]; + if (profile) { + const envStaticCredentialsAreSet = process.env[ENV_KEY] && process.env[ENV_SECRET]; + if (envStaticCredentialsAreSet) { + if (!multipleCredentialSourceWarningEmitted) { + const warnFn = + init.logger?.warn && init.logger?.constructor?.name !== "NoOpLogger" ? init.logger.warn : console.warn; + warnFn( + `@aws-sdk/credential-provider-node - defaultProvider::fromEnv WARNING: Multiple credential sources detected: Both AWS_PROFILE and the pair AWS_ACCESS_KEY_ID/AWS_SECRET_ACCESS_KEY static credentials are set. This SDK will proceed with the AWS_PROFILE value. @@ -81,59 +85,75 @@ export const defaultProvider = (init: DefaultProviderInit = {}): MemoizedProvide Please ensure that your environment only sets either the AWS_PROFILE or the AWS_ACCESS_KEY_ID/AWS_SECRET_ACCESS_KEY pair. ` - ); - multipleCredentialSourceWarningEmitted = true; - } + ); + multipleCredentialSourceWarningEmitted = true; } - throw new CredentialsProviderError("AWS_PROFILE is set, skipping fromEnv provider.", { - logger: init.logger, - tryNextLink: true, - }); - } - init.logger?.debug("@aws-sdk/credential-provider-node - defaultProvider::fromEnv"); - return fromEnv(init)(); - }, - async () => { - init.logger?.debug("@aws-sdk/credential-provider-node - defaultProvider::fromSSO"); - const { ssoStartUrl, ssoAccountId, ssoRegion, ssoRoleName, ssoSession } = init; - if (!ssoStartUrl && !ssoAccountId && !ssoRegion && !ssoRoleName && !ssoSession) { - throw new CredentialsProviderError( - "Skipping SSO provider in default chain (inputs do not include SSO fields).", - { logger: init.logger } - ); } - const { fromSSO } = await import("@aws-sdk/credential-provider-sso"); - return fromSSO(init)(); - }, - async () => { - init.logger?.debug("@aws-sdk/credential-provider-node - defaultProvider::fromIni"); - const { fromIni } = await import("@aws-sdk/credential-provider-ini"); - return fromIni(init)(); - }, - async () => { - init.logger?.debug("@aws-sdk/credential-provider-node - defaultProvider::fromProcess"); - const { fromProcess } = await import("@aws-sdk/credential-provider-process"); - return fromProcess(init)(); - }, - async () => { - init.logger?.debug("@aws-sdk/credential-provider-node - defaultProvider::fromTokenFile"); - const { fromTokenFile } = await import("@aws-sdk/credential-provider-web-identity"); - return fromTokenFile(init)(); - }, - async () => { - init.logger?.debug("@aws-sdk/credential-provider-node - defaultProvider::remoteProvider"); - return (await remoteProvider(init))(); - }, - async () => { - throw new CredentialsProviderError("Could not load credentials from any providers", { - tryNextLink: false, + throw new CredentialsProviderError("AWS_PROFILE is set, skipping fromEnv provider.", { logger: init.logger, + tryNextLink: true, }); } - ), - credentialsTreatedAsExpired, - credentialsWillNeedRefresh + init.logger?.debug("@aws-sdk/credential-provider-node - defaultProvider::fromEnv"); + return fromEnv(init)(); + }, + async () => { + init.logger?.debug("@aws-sdk/credential-provider-node - defaultProvider::fromSSO"); + const { ssoStartUrl, ssoAccountId, ssoRegion, ssoRoleName, ssoSession } = init; + if (!ssoStartUrl && !ssoAccountId && !ssoRegion && !ssoRoleName && !ssoSession) { + throw new CredentialsProviderError( + "Skipping SSO provider in default chain (inputs do not include SSO fields).", + { logger: init.logger } + ); + } + const { fromSSO } = await import("@aws-sdk/credential-provider-sso"); + return fromSSO(init)(); + }, + async () => { + init.logger?.debug("@aws-sdk/credential-provider-node - defaultProvider::fromIni"); + const { fromIni } = await import("@aws-sdk/credential-provider-ini"); + return fromIni(init)(); + }, + async () => { + init.logger?.debug("@aws-sdk/credential-provider-node - defaultProvider::fromProcess"); + const { fromProcess } = await import("@aws-sdk/credential-provider-process"); + return fromProcess(init)(); + }, + async () => { + init.logger?.debug("@aws-sdk/credential-provider-node - defaultProvider::fromTokenFile"); + const { fromTokenFile } = await import("@aws-sdk/credential-provider-web-identity"); + return fromTokenFile(init)(); + }, + async () => { + init.logger?.debug("@aws-sdk/credential-provider-node - defaultProvider::remoteProvider"); + return (await remoteProvider(init))(); + }, + async () => { + throw new CredentialsProviderError("Could not load credentials from any providers", { + tryNextLink: false, + logger: init.logger, + }); + } + ); + + return memoize( + providerChain, + (credentials) => { + const cached = credentialCache.get(providerChain); + if (cached && !credentialsTreatedAsExpired(cached)) { + return true; + } + return credentialsTreatedAsExpired(credentials); + }, + (credentials) => { + const needsRefresh = credentialsWillNeedRefresh(credentials); + if (!needsRefresh) { + credentialCache.set(providerChain, credentials); + } + return needsRefresh; + } ); +}; /** * @internal From 38f3de57aa938fc1b282a457d2cea2710bcd7004 Mon Sep 17 00:00:00 2001 From: cbudau Date: Fri, 21 Mar 2025 20:27:41 +0200 Subject: [PATCH 2/2] feat(credential-providers): cache credentials across clients Updated global memoize and integration tests --- .../credential-provider-node.integ.spec.ts | 88 +++++++++++++++++++ .../src/defaultProvider.ts | 33 +++---- .../src/memoizeGlobal.ts | 44 ++++++++++ 3 files changed, 145 insertions(+), 20 deletions(-) create mode 100644 packages/credential-provider-node/src/memoizeGlobal.ts diff --git a/packages/credential-provider-node/src/credential-provider-node.integ.spec.ts b/packages/credential-provider-node/src/credential-provider-node.integ.spec.ts index 4d65fcb757cb4..6a7886b175c4f 100644 --- a/packages/credential-provider-node/src/credential-provider-node.integ.spec.ts +++ b/packages/credential-provider-node/src/credential-provider-node.integ.spec.ts @@ -8,6 +8,7 @@ import { AdaptiveRetryStrategy, StandardRetryStrategy } from "@smithy/util-retry import { PassThrough } from "stream"; import { defaultProvider } from "./defaultProvider"; +import { clearDefaultProviderCache } from "./memoizeGlobal"; jest.mock("fs", () => { const actual = jest.requireActual("fs"); @@ -1273,4 +1274,91 @@ describe("credential-provider-node integration test", () => { expect(async () => sts.getCallerIdentity({})).rejects.toThrow("Could not load credentials from any providers"); }); }); + + describe("Global Cache Behavior", () => { + beforeEach(() => { + clearDefaultProviderCache(); + jest.clearAllMocks(); + for (const variable in RESERVED_ENVIRONMENT_VARIABLES) { + delete process.env[variable]; + } + }); + + afterEach(() => { + clearDefaultProviderCache(); + }); + + it("should cache credentials across provider instances", async () => { + // Set up environment credentials to avoid profile warning + process.env.AWS_ACCESS_KEY_ID = "AKID"; + process.env.AWS_SECRET_ACCESS_KEY = "SECRET"; + + const provider1 = defaultProvider(); + const provider2 = defaultProvider(); + + const creds1 = await provider1(); + const creds2 = await provider2(); + + expect(creds1).toEqual(creds2); + expect(creds1).toEqual({ + accessKeyId: "AKID", + secretAccessKey: "SECRET", + $source: { + CREDENTIALS_ENV_VARS: "g", + }, + }); + }); + + it("should maintain separate caches for different profiles", async () => { + // Clear env variables to allow profile credentials + delete process.env.AWS_ACCESS_KEY_ID; + delete process.env.AWS_SECRET_ACCESS_KEY; + + Object.assign(iniProfileData, { + profile1: { + aws_access_key_id: "AKID1", + aws_secret_access_key: "SECRET1", + }, + profile2: { + aws_access_key_id: "AKID2", + aws_secret_access_key: "SECRET2", + }, + }); + + const provider1 = defaultProvider({ profile: "profile1" }); + const provider2 = defaultProvider({ profile: "profile2" }); + + const [creds1, creds2] = await Promise.all([provider1(), provider2()]); + + expect(creds1.accessKeyId).toBe("AKID1"); + expect(creds2.accessKeyId).toBe("AKID2"); + expect(creds1).not.toEqual(creds2); + }); + + it("should handle expired credentials", async () => { + process.env.AWS_ACCESS_KEY_ID = "AKID"; + process.env.AWS_SECRET_ACCESS_KEY = "SECRET"; + + const provider = defaultProvider(); + const creds = await provider(); + + // Simulate expiration + Object.defineProperty(creds, "expiration", { + value: new Date(Date.now() - 300001), // Just over 5 minutes ago + }); + + // Should force a refresh on next call + const newCreds = await provider(); + expect(newCreds).toBeDefined(); + expect(newCreds.accessKeyId).toBe("AKID"); + }); + + it("should handle provider errors", async () => { + delete process.env.AWS_ACCESS_KEY_ID; + delete process.env.AWS_SECRET_ACCESS_KEY; + + const provider = defaultProvider(); + await expect(provider()).rejects.toThrow("Could not load credentials from any providers"); + }); + }); }); diff --git a/packages/credential-provider-node/src/defaultProvider.ts b/packages/credential-provider-node/src/defaultProvider.ts index a25fe2a5459b9..4c02a17afa992 100644 --- a/packages/credential-provider-node/src/defaultProvider.ts +++ b/packages/credential-provider-node/src/defaultProvider.ts @@ -9,6 +9,7 @@ import { chain, CredentialsProviderError, memoize } from "@smithy/property-provi import { ENV_PROFILE } from "@smithy/shared-ini-file-loader"; import { AwsCredentialIdentity, MemoizedProvider } from "@smithy/types"; +import { memoizeGlobal } from "./memoizeGlobal"; import { remoteProvider } from "./remoteProvider"; /** @@ -26,11 +27,6 @@ export type DefaultProviderInit = FromIniInit & */ let multipleCredentialSourceWarningEmitted = false; -/** - * @internal - */ -const credentialCache = new WeakMap<() => Promise, AwsCredentialIdentity>(); - /** * Creates a credential provider that will attempt to find credentials from the * following sources (listed in order of precedence): @@ -66,7 +62,7 @@ const credentialCache = new WeakMap<() => Promise, AwsCre * ECS Container Metadata Service. */ export const defaultProvider = (init: DefaultProviderInit = {}): MemoizedProvider => { - const providerChain = chain( + const provider = chain( async () => { const profile = init.profile ?? process.env[ENV_PROFILE]; if (profile) { @@ -136,22 +132,19 @@ export const defaultProvider = (init: DefaultProviderInit = {}): MemoizedProvide } ); - return memoize( - providerChain, - (credentials) => { - const cached = credentialCache.get(providerChain); - if (cached && !credentialsTreatedAsExpired(cached)) { - return true; + return memoizeGlobal( + async () => { + try { + return await provider(); + } catch (error) { + if (error instanceof CredentialsProviderError) { + throw error; + } + throw new CredentialsProviderError(error.message, { tryNextLink: true }); } - return credentialsTreatedAsExpired(credentials); }, - (credentials) => { - const needsRefresh = credentialsWillNeedRefresh(credentials); - if (!needsRefresh) { - credentialCache.set(providerChain, credentials); - } - return needsRefresh; - } + credentialsTreatedAsExpired, + credentialsWillNeedRefresh ); }; diff --git a/packages/credential-provider-node/src/memoizeGlobal.ts b/packages/credential-provider-node/src/memoizeGlobal.ts new file mode 100644 index 0000000000000..028e6d6b9a9ac --- /dev/null +++ b/packages/credential-provider-node/src/memoizeGlobal.ts @@ -0,0 +1,44 @@ +import { memoize } from "@smithy/property-provider"; +import { AwsCredentialIdentity } from "@smithy/types"; + +const globalProviderCache: Map Promise> = new Map(); + +function hashProvider(provider: () => Promise, config?: string): string { + return config || provider.name || Math.random().toString(36).substring(7); +} + +export function memoizeGlobal( + provider: () => Promise, + isExpired: (resolved: T) => boolean, + requiresRefresh?: (resolved: T) => boolean, + cacheKey?: string +): () => Promise { + const key = hashProvider(provider, cacheKey); + const cached = globalProviderCache.get(key); + if (cached) { + return cached as () => Promise; + } + + const memoized = memoize(provider, isExpired, requiresRefresh); + const wrappedProvider = async () => { + try { + const creds = await memoized(); + if (isExpired(creds)) { + globalProviderCache.delete(key); + // Force memoize to refresh by calling provider directly + return await provider(); + } + return creds; + } catch (error) { + globalProviderCache.delete(key); + throw error; + } + }; + + globalProviderCache.set(key, wrappedProvider); + return wrappedProvider; +} + +export function clearDefaultProviderCache(): void { + globalProviderCache.clear(); +}