Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 2 additions & 8 deletions pkg/storegateway/bucket_stores.go
Original file line number Diff line number Diff line change
Expand Up @@ -786,16 +786,10 @@ func (u *BucketStores) getTokensToRetrieve(tokens uint64, dataType store.StoreDa
}

func getUserIDFromGRPCContext(ctx context.Context) string {
meta, ok := metadata.FromIncomingContext(ctx)
if !ok {
values := metadata.ValueFromIncomingContext(ctx, tsdb.TenantIDExternalLabel)
if values == nil || len(values) != 1 {
return ""
}

values := meta.Get(tsdb.TenantIDExternalLabel)
if len(values) != 1 {
return ""
}

return values[0]
}

Expand Down
8 changes: 2 additions & 6 deletions pkg/util/extract_forwarded.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,8 @@ func GetSourceIPsFromOutgoingCtx(ctx context.Context) string {

// GetSourceIPsFromIncomingCtx extracts the source field from the GRPC context
func GetSourceIPsFromIncomingCtx(ctx context.Context) string {
md, ok := metadata.FromIncomingContext(ctx)
if !ok {
return ""
}
ipAddresses, ok := md[ipAddressesKey]
if !ok {
ipAddresses := metadata.ValueFromIncomingContext(ctx, ipAddressesKey)
if ipAddresses == nil || len(ipAddresses) != 1 {
return ""
}
return ipAddresses[0]
Expand Down
12 changes: 2 additions & 10 deletions pkg/util/grpcclient/signing_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,20 +33,12 @@ func UnarySigningServerInterceptor(ctx context.Context, req any, _ *grpc.UnarySe
return handler(ctx, req)
}

md, ok := metadata.FromIncomingContext(ctx)

if !ok {
return nil, ErrSignatureNotPresent
}

sig, ok := md[reqSignHeaderName]

if !ok || len(sig) != 1 {
sig := metadata.ValueFromIncomingContext(ctx, reqSignHeaderName)
if sig == nil || len(sig) != 1 {
return nil, ErrSignatureNotPresent
}

valid, err := rs.VerifySign(ctx, sig[0])

if err != nil {
return nil, err
}
Expand Down
12 changes: 9 additions & 3 deletions pkg/util/grpcutil/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,17 @@ func HTTPHeaderPropagationStreamServerInterceptor(srv any, ss grpc.ServerStream,
// extractForwardedRequestMetadataFromMetadata implements HTTPHeaderPropagationServerInterceptor by placing forwarded
// headers into incoming context
func extractForwardedRequestMetadataFromMetadata(ctx context.Context) context.Context {
md, ok := metadata.FromIncomingContext(ctx)
if !ok {
headersSlice := metadata.ValueFromIncomingContext(ctx, requestmeta.PropagationStringForRequestMetadata)
if headersSlice == nil {
// we want to check old key if no data
headersSlice = metadata.ValueFromIncomingContext(ctx, requestmeta.HeaderPropagationStringForRequestLogging)
}

if headersSlice == nil {
return ctx
}
return requestmeta.ContextWithRequestMetadataMapFromMetadata(ctx, md)

return requestmeta.ContextWithRequestMetadataMapFromHeaderSlice(ctx, headersSlice)
}

// HTTPHeaderPropagationClientInterceptor allows for propagation of HTTP Request headers across gRPC calls - works
Expand Down
84 changes: 84 additions & 0 deletions pkg/util/grpcutil/util_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
package grpcutil

import (
"context"
"testing"

"github.com/stretchr/testify/assert"
"google.golang.org/grpc/metadata"

"github.com/cortexproject/cortex/pkg/util/requestmeta"
)

// TestExtractForwardedRequestMetadataFromMetadata tests the extractForwardedRequestMetadataFromMetadata function
func TestExtractForwardedRequestMetadataFromMetadata(t *testing.T) {
tests := []struct {
name string
ctx context.Context
expectedResult map[string]string
}{
{
name: "context without metadata",
ctx: context.Background(),
expectedResult: nil,
},
{
name: "context with new metadata key",
ctx: func() context.Context {
md := metadata.New(nil)
md.Append(requestmeta.PropagationStringForRequestMetadata, "header1", "value1", "header2", "value2")
return metadata.NewIncomingContext(context.Background(), md)
}(),
expectedResult: map[string]string{
"header1": "value1",
"header2": "value2",
},
},
{
name: "context with old metadata key",
ctx: func() context.Context {
md := metadata.New(nil)
md.Append(requestmeta.HeaderPropagationStringForRequestLogging, "header1", "value1", "header2", "value2")
return metadata.NewIncomingContext(context.Background(), md)
}(),
expectedResult: map[string]string{
"header1": "value1",
"header2": "value2",
},
},
{
name: "context with both keys, new key takes precedence",
ctx: func() context.Context {
md := metadata.New(nil)
md.Append(requestmeta.PropagationStringForRequestMetadata, "newheader", "newvalue")
md.Append(requestmeta.HeaderPropagationStringForRequestLogging, "oldheader", "oldvalue")
return metadata.NewIncomingContext(context.Background(), md)
}(),
expectedResult: map[string]string{
"newheader": "newvalue",
},
},
{
name: "context with odd number of metadata values",
ctx: func() context.Context {
md := metadata.New(nil)
md.Append(requestmeta.PropagationStringForRequestMetadata, "header1", "value1", "header2")
return metadata.NewIncomingContext(context.Background(), md)
}(),
expectedResult: nil,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := extractForwardedRequestMetadataFromMetadata(tt.ctx)
metadataMap := requestmeta.MapFromContext(result)

if tt.expectedResult == nil {
assert.Nil(t, metadataMap)
} else {
assert.Equal(t, tt.expectedResult, metadataMap)
}
})
}
}
22 changes: 15 additions & 7 deletions pkg/util/requestmeta/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,19 @@ func InjectMetadataIntoHTTPRequestHeaders(requestMetadataMap map[string]string,
}
}

func ContextWithRequestMetadataMapFromHeaderSlice(ctx context.Context, headerSlice []string) context.Context {
if len(headerSlice)%2 == 1 {
return ctx
}

requestMetadataMap := make(map[string]string, len(headerSlice)/2)
for i := 0; i < len(headerSlice); i += 2 {
requestMetadataMap[headerSlice[i]] = headerSlice[i+1]
}

return ContextWithRequestMetadataMap(ctx, requestMetadataMap)
}

func ContextWithRequestMetadataMapFromMetadata(ctx context.Context, md metadata.MD) context.Context {
headersSlice, ok := md[PropagationStringForRequestMetadata]

Expand All @@ -63,14 +76,9 @@ func ContextWithRequestMetadataMapFromMetadata(ctx context.Context, md metadata.
headersSlice, ok = md[HeaderPropagationStringForRequestLogging]
}

if !ok || len(headersSlice)%2 == 1 {
if !ok {
return ctx
}

requestMetadataMap := make(map[string]string)
for i := 0; i < len(headersSlice); i += 2 {
requestMetadataMap[headersSlice[i]] = headersSlice[i+1]
}

return ContextWithRequestMetadataMap(ctx, requestMetadataMap)
return ContextWithRequestMetadataMapFromHeaderSlice(ctx, headersSlice)
}
82 changes: 82 additions & 0 deletions pkg/util/requestmeta/context_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,3 +111,85 @@ func TestInjectMetadataIntoHTTPRequestHeaders(t *testing.T) {
require.Equal(t, "ContentsOfTestHeader2", header2[0])

}

func TestContextWithRequestMetadataMapFromHeaderSlice(t *testing.T) {
tests := []struct {
name string
headerSlice []string
expectedResult map[string]string
}{
{
name: "empty header slice",
headerSlice: []string{},
expectedResult: map[string]string{},
},
{
name: "nil header slice",
headerSlice: nil,
expectedResult: map[string]string{},
},
{
name: "odd number of elements",
headerSlice: []string{"header1", "value1", "header2"},
expectedResult: nil,
},
{
name: "single key-value pair",
headerSlice: []string{"header1", "value1"},
expectedResult: map[string]string{
"header1": "value1",
},
},
{
name: "multiple key-value pairs",
headerSlice: []string{"header1", "value1", "header2", "value2", "header3", "value3"},
expectedResult: map[string]string{
"header1": "value1",
"header2": "value2",
"header3": "value3",
},
},
{
name: "duplicate keys (last value wins)",
headerSlice: []string{"header1", "value1", "header1", "value2"},
expectedResult: map[string]string{
"header1": "value2",
},
},
{
name: "empty values",
headerSlice: []string{"header1", "", "header2", "value2"},
expectedResult: map[string]string{
"header1": "",
"header2": "value2",
},
},
{
name: "special characters in keys and values",
headerSlice: []string{"header-1", "value with spaces", "header_2", "value-with-dashes"},
expectedResult: map[string]string{
"header-1": "value with spaces",
"header_2": "value-with-dashes",
},
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := context.Background()
result := ContextWithRequestMetadataMapFromHeaderSlice(ctx, tt.headerSlice)
metadataMap := MapFromContext(result)

if tt.expectedResult == nil {
require.Nil(t, metadataMap)
} else {
require.NotNil(t, metadataMap)
require.Equal(t, len(tt.expectedResult), len(metadataMap))
for key, expectedValue := range tt.expectedResult {
require.Contains(t, metadataMap, key)
require.Equal(t, expectedValue, metadataMap[key])
}
}
})
}
}
Loading