Skip to content

Commit

Permalink
fix(sts): fix sts regional endpoint injection under several cases
Browse files Browse the repository at this point in the history
  • Loading branch information
windmgc committed Jul 24, 2024
1 parent 48facd6 commit 1f6cb8f
Show file tree
Hide file tree
Showing 2 changed files with 218 additions and 7 deletions.
188 changes: 188 additions & 0 deletions spec/04-services/05-sts_spec.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
setmetatable(_G, nil)

-- -- hock request sending
-- package.loaded["resty.aws.request.execute"] = function(...)
-- return ...
-- end

local AWS = require("resty.aws")
local AWS_global_config = require("resty.aws.config").global

local config = AWS_global_config
local aws = AWS(config)

aws.config.credentials = aws:Credentials {
accessKeyId = "test_id",
secretAccessKey = "test_key",
}

-- aws.config.region = "test_region"

local test_assume_role_arn = "arn:aws:iam::123456789012:role/test-role"
local test_role_session_name = "lua-resty-aws-test-assumeRole"

describe("STS service", function()
local origin_time
setup(function()
origin_time = ngx.time
ngx.time = function () --luacheck: ignore
return 1667543171
end
end)

teardown(function ()
ngx.time = origin_time --luacheck: ignore
end)

-- before_each(function()
-- sts = aws:STS()
-- end)

-- after_each(function()

-- end)

for _, region in ipairs({"us-east-1", "us-east-2", "ap-south-1", "ca-west-1", "eu-west-2", "sa-east-1"}) do
describe("In Region #" .. region, function ()
-- before_each(function()
-- aws.config.region = region
-- end)

it("AWS_STS_REGIONAL_ENDPOINT==regional with default endpoint", function ()
local config = {
region = region,
stsRegionalEndpoints = "regional",
dry_run = true,
}

local sts = aws:STS(config)
local request = sts:assumeRole({
RoleArn = test_assume_role_arn,
RoleSessionName = test_role_session_name,
})

assert.same(sts.config.stsRegionalEndpoints, "regional")
-- Check the signing region has been injected
assert.same(region, sts.config.signingRegion)
assert.truthy(sts.config._regionalEndpointInjected)
-- Check the endpoint has been injected
assert.same(sts.config.endpoint, "https://sts." .. region .. ".amazonaws.com")
assert.not_nil(request.headers.Authorization:find(region, 1, true))
end)

describe("AWS_STS_REGIONAL_ENDPOINT==regional with non-default endpoint", function()
it("and endpoint is regional domain", function ()
local config = {
region = region,
stsRegionalEndpoints = "regional",
endpoint = "https://sts." .. region .. ".amazonaws.com",
dry_run = true,
}

local sts = aws:STS(config)
local request = sts:assumeRole({
RoleArn = test_assume_role_arn,
RoleSessionName = test_role_session_name,
})

assert.same(sts.config.stsRegionalEndpoints, "regional")
-- Check the signing region has been injected
assert.same(region, sts.config.signingRegion)
assert.truthy(sts.config._regionalEndpointInjected)
-- Check thes endpoint has not been injected twice
assert.same(sts.config.endpoint, config.endpoint)
assert.not_nil(request.headers.Authorization:find(region, 1, true))
end)

it("and endpoint is global domain", function ()
local config = {
region = region,
stsRegionalEndpoints = "regional",
endpoint = "https://sts.amazonaws.com",
dry_run = true,
}

local sts = aws:STS(config)
local request = sts:assumeRole({
RoleArn = test_assume_role_arn,
RoleSessionName = test_role_session_name,
})

assert.same(sts.config.stsRegionalEndpoints, "regional")
-- Check the signing region has been injected
assert.same(region, sts.config.signingRegion)
assert.truthy(sts.config._regionalEndpointInjected)
-- Check the endpoint has been injected
assert.same(sts.config.endpoint, "https://sts." .. region .. ".amazonaws.com")
assert.not_nil(request.headers.Authorization:find(region, 1, true))
end)

it("and endpoint is region VPC endpoint", function ()
local config = {
region = region,
stsRegionalEndpoints = "regional",
endpoint = "https://vpce-1234567-abcdefg.sts." .. region .. ".vpce.amazonaws.com",
dry_run = true,
}

local sts = aws:STS(config)
local request = sts:assumeRole({
RoleArn = test_assume_role_arn,
RoleSessionName = test_role_session_name,
})

assert.same(sts.config.stsRegionalEndpoints, "regional")
-- Check the signing region has been injected
assert.same(region, sts.config.signingRegion)
assert.truthy(sts.config._regionalEndpointInjected)
-- Check the endpoint has not been injected when endpoint is a vpc endpoint
assert.same(sts.config.endpoint, config.endpoint)
assert.not_nil(request.headers.Authorization:find(region, 1, true))
end)

it("and endpoint is AZ VPC endpoint", function ()
local config = {
region = region,
stsRegionalEndpoints = "regional",
endpoint = "https://vpce-1234567-abcdefg-" .. region .. "c" .. ".sts." .. region .. ".vpce.amazonaws.com",
dry_run = true,
}

local sts = aws:STS(config)
local request = sts:assumeRole({
RoleArn = test_assume_role_arn,
RoleSessionName = test_role_session_name,
})

assert.same(sts.config.stsRegionalEndpoints, "regional")
-- Check the signing region has been injected
assert.same(region, sts.config.signingRegion)
assert.truthy(sts.config._regionalEndpointInjected)
-- Check the endpoint has not been injected when endpoint is a vpc endpoint
assert.same(sts.config.endpoint, config.endpoint)
assert.not_nil(request.headers.Authorization:find(region, 1, true))
end)
end)

it("AWS_STS_REGIONAL_ENDPOINT==legacy with default endpoint", function ()
local config = {
region = region,
stsRegionalEndpoints = "legacy",
dry_run = true,
}

local sts = aws:STS(config)
local request = sts:assumeRole({
RoleArn = test_assume_role_arn,
RoleSessionName = test_role_session_name,
})

assert.same(sts.config.stsRegionalEndpoints, "legacy")
assert.same("us-east-1", sts.config.signingRegion)
assert.is_nil(sts.config._regionalEndpointInjected)
assert.same(sts.config.endpoint, "https://sts.amazonaws.com")
assert.not_nil(request.headers.Authorization:find("us-east-1", 1, true))
end)
end)
end
end)
37 changes: 30 additions & 7 deletions src/resty/aws/init.lua
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,24 @@ do
end


local isRegionalSTSDomain do
-- from the list described in https://docs.aws.amazon.com/IAM/latest/UserGuide/id_credentials_temp_enable-regions.html
-- TODO: not sure if gov cloud also has their own endpoints so leave it for now
local stsRegionRegexes = {
[[sts\.(us|eu|ap|sa|ca|me)\-\w+\-\d+\.amazonaws\.com$]],
[[sts\.cn\-\w+\-\d+\.amazonaws\.com\.cn$]],
}

function isRegionalSTSDomain(domain)
for _, entry in ipairs(stsRegionRegexes) do
if ngx.re.match(domain, entry, "jo") then
return true
end
end

return false
end
end

-- written from scratch

Expand Down Expand Up @@ -325,14 +343,19 @@ local function generate_service_methods(service)
-- https://github.com/aws/aws-sdk-js/blob/307e82673b48577fce4389e4ce03f95064e8fe0d/lib/services/sts.js#L78-L82
assert(service.config.region, "region is required when using STS regional endpoints")

-- If the endpoint is a VPC endpoint DNS hostname then we don't need to inject the region
-- VPC endpoint DNS hostnames always contain region, see
-- https://docs.aws.amazon.com/vpc/latest/privatelink/privatelink-access-aws-services.html#interface-endpoint-dns-hostnames
if not service.config._regionalEndpointInjected and not service.config.endpoint:match(AWS_VPC_ENDPOINT_DOMAIN_PATTERN) then
local pre, post = service.config.endpoint:match(AWS_PUBLIC_DOMAIN_PATTERN)
service.config.endpoint = pre .. "." .. service.config.region .. post
service.config.signingRegion = service.config.region
if not service.config._regionalEndpointInjected then
service.config._regionalEndpointInjected = true
-- stsRegionalEndpoints is set to 'regional', so inject region into the
-- signingRegion to override global region_config_data
service.config.signingRegion = service.config.region

-- If the endpoint is a VPC endpoint DNS hostname, or a regional STS domain, then we don't need to inject the region
-- VPC endpoint DNS hostnames always contain region, see
-- https://docs.aws.amazon.com/vpc/latest/privatelink/privatelink-access-aws-services.html#interface-endpoint-dns-hostnames
if not service.config.endpoint:match(AWS_VPC_ENDPOINT_DOMAIN_PATTERN) and not isRegionalSTSDomain(service.config.endpoint) then
local pre, post = service.config.endpoint:match(AWS_PUBLIC_DOMAIN_PATTERN)
service.config.endpoint = pre .. "." .. service.config.region .. post
end
end
end

Expand Down

0 comments on commit 1f6cb8f

Please sign in to comment.