diff --git a/pkg/storegateway/bucket_stores.go b/pkg/storegateway/bucket_stores.go index b9da057ae23..dea6d8a3c4a 100644 --- a/pkg/storegateway/bucket_stores.go +++ b/pkg/storegateway/bucket_stores.go @@ -786,16 +786,10 @@ func (u *BucketStores) getTokensToRetrieve(tokens uint64, dataType store.StoreDa } func getUserIDFromGRPCContext(ctx context.Context) string { - meta, ok := metadata.FromIncomingContext(ctx) - if !ok { + values := metadata.ValueFromIncomingContext(ctx, tsdb.TenantIDExternalLabel) + if values == nil || len(values) != 1 { return "" } - - values := meta.Get(tsdb.TenantIDExternalLabel) - if len(values) != 1 { - return "" - } - return values[0] } diff --git a/pkg/util/extract_forwarded.go b/pkg/util/extract_forwarded.go index 9cacf3b3a4e..8c1aae3822f 100644 --- a/pkg/util/extract_forwarded.go +++ b/pkg/util/extract_forwarded.go @@ -24,12 +24,8 @@ func GetSourceIPsFromOutgoingCtx(ctx context.Context) string { // GetSourceIPsFromIncomingCtx extracts the source field from the GRPC context func GetSourceIPsFromIncomingCtx(ctx context.Context) string { - md, ok := metadata.FromIncomingContext(ctx) - if !ok { - return "" - } - ipAddresses, ok := md[ipAddressesKey] - if !ok { + ipAddresses := metadata.ValueFromIncomingContext(ctx, ipAddressesKey) + if ipAddresses == nil || len(ipAddresses) != 1 { return "" } return ipAddresses[0] diff --git a/pkg/util/grpcclient/signing_handler.go b/pkg/util/grpcclient/signing_handler.go index c402c963aa5..a6f5ee2f736 100644 --- a/pkg/util/grpcclient/signing_handler.go +++ b/pkg/util/grpcclient/signing_handler.go @@ -33,20 +33,12 @@ func UnarySigningServerInterceptor(ctx context.Context, req any, _ *grpc.UnarySe return handler(ctx, req) } - md, ok := metadata.FromIncomingContext(ctx) - - if !ok { - return nil, ErrSignatureNotPresent - } - - sig, ok := md[reqSignHeaderName] - - if !ok || len(sig) != 1 { + sig := metadata.ValueFromIncomingContext(ctx, reqSignHeaderName) + if sig == nil || len(sig) != 1 { return nil, ErrSignatureNotPresent } valid, err := rs.VerifySign(ctx, sig[0]) - if err != nil { return nil, err } diff --git a/pkg/util/grpcutil/util.go b/pkg/util/grpcutil/util.go index b9e4da4afdb..acad4ba897b 100644 --- a/pkg/util/grpcutil/util.go +++ b/pkg/util/grpcutil/util.go @@ -51,11 +51,17 @@ func HTTPHeaderPropagationStreamServerInterceptor(srv any, ss grpc.ServerStream, // extractForwardedRequestMetadataFromMetadata implements HTTPHeaderPropagationServerInterceptor by placing forwarded // headers into incoming context func extractForwardedRequestMetadataFromMetadata(ctx context.Context) context.Context { - md, ok := metadata.FromIncomingContext(ctx) - if !ok { + headersSlice := metadata.ValueFromIncomingContext(ctx, requestmeta.PropagationStringForRequestMetadata) + if headersSlice == nil { + // we want to check old key if no data + headersSlice = metadata.ValueFromIncomingContext(ctx, requestmeta.HeaderPropagationStringForRequestLogging) + } + + if headersSlice == nil { return ctx } - return requestmeta.ContextWithRequestMetadataMapFromMetadata(ctx, md) + + return requestmeta.ContextWithRequestMetadataMapFromHeaderSlice(ctx, headersSlice) } // HTTPHeaderPropagationClientInterceptor allows for propagation of HTTP Request headers across gRPC calls - works diff --git a/pkg/util/grpcutil/util_test.go b/pkg/util/grpcutil/util_test.go new file mode 100644 index 00000000000..5dacff00143 --- /dev/null +++ b/pkg/util/grpcutil/util_test.go @@ -0,0 +1,84 @@ +package grpcutil + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "google.golang.org/grpc/metadata" + + "github.com/cortexproject/cortex/pkg/util/requestmeta" +) + +// TestExtractForwardedRequestMetadataFromMetadata tests the extractForwardedRequestMetadataFromMetadata function +func TestExtractForwardedRequestMetadataFromMetadata(t *testing.T) { + tests := []struct { + name string + ctx context.Context + expectedResult map[string]string + }{ + { + name: "context without metadata", + ctx: context.Background(), + expectedResult: nil, + }, + { + name: "context with new metadata key", + ctx: func() context.Context { + md := metadata.New(nil) + md.Append(requestmeta.PropagationStringForRequestMetadata, "header1", "value1", "header2", "value2") + return metadata.NewIncomingContext(context.Background(), md) + }(), + expectedResult: map[string]string{ + "header1": "value1", + "header2": "value2", + }, + }, + { + name: "context with old metadata key", + ctx: func() context.Context { + md := metadata.New(nil) + md.Append(requestmeta.HeaderPropagationStringForRequestLogging, "header1", "value1", "header2", "value2") + return metadata.NewIncomingContext(context.Background(), md) + }(), + expectedResult: map[string]string{ + "header1": "value1", + "header2": "value2", + }, + }, + { + name: "context with both keys, new key takes precedence", + ctx: func() context.Context { + md := metadata.New(nil) + md.Append(requestmeta.PropagationStringForRequestMetadata, "newheader", "newvalue") + md.Append(requestmeta.HeaderPropagationStringForRequestLogging, "oldheader", "oldvalue") + return metadata.NewIncomingContext(context.Background(), md) + }(), + expectedResult: map[string]string{ + "newheader": "newvalue", + }, + }, + { + name: "context with odd number of metadata values", + ctx: func() context.Context { + md := metadata.New(nil) + md.Append(requestmeta.PropagationStringForRequestMetadata, "header1", "value1", "header2") + return metadata.NewIncomingContext(context.Background(), md) + }(), + expectedResult: nil, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := extractForwardedRequestMetadataFromMetadata(tt.ctx) + metadataMap := requestmeta.MapFromContext(result) + + if tt.expectedResult == nil { + assert.Nil(t, metadataMap) + } else { + assert.Equal(t, tt.expectedResult, metadataMap) + } + }) + } +} diff --git a/pkg/util/requestmeta/context.go b/pkg/util/requestmeta/context.go index 43ee33c7bc4..190ec3f8061 100644 --- a/pkg/util/requestmeta/context.go +++ b/pkg/util/requestmeta/context.go @@ -55,6 +55,19 @@ func InjectMetadataIntoHTTPRequestHeaders(requestMetadataMap map[string]string, } } +func ContextWithRequestMetadataMapFromHeaderSlice(ctx context.Context, headerSlice []string) context.Context { + if len(headerSlice)%2 == 1 { + return ctx + } + + requestMetadataMap := make(map[string]string, len(headerSlice)/2) + for i := 0; i < len(headerSlice); i += 2 { + requestMetadataMap[headerSlice[i]] = headerSlice[i+1] + } + + return ContextWithRequestMetadataMap(ctx, requestMetadataMap) +} + func ContextWithRequestMetadataMapFromMetadata(ctx context.Context, md metadata.MD) context.Context { headersSlice, ok := md[PropagationStringForRequestMetadata] @@ -63,14 +76,9 @@ func ContextWithRequestMetadataMapFromMetadata(ctx context.Context, md metadata. headersSlice, ok = md[HeaderPropagationStringForRequestLogging] } - if !ok || len(headersSlice)%2 == 1 { + if !ok { return ctx } - requestMetadataMap := make(map[string]string) - for i := 0; i < len(headersSlice); i += 2 { - requestMetadataMap[headersSlice[i]] = headersSlice[i+1] - } - - return ContextWithRequestMetadataMap(ctx, requestMetadataMap) + return ContextWithRequestMetadataMapFromHeaderSlice(ctx, headersSlice) } diff --git a/pkg/util/requestmeta/context_test.go b/pkg/util/requestmeta/context_test.go index 23a0d3b4dab..795dcce38f7 100644 --- a/pkg/util/requestmeta/context_test.go +++ b/pkg/util/requestmeta/context_test.go @@ -111,3 +111,85 @@ func TestInjectMetadataIntoHTTPRequestHeaders(t *testing.T) { require.Equal(t, "ContentsOfTestHeader2", header2[0]) } + +func TestContextWithRequestMetadataMapFromHeaderSlice(t *testing.T) { + tests := []struct { + name string + headerSlice []string + expectedResult map[string]string + }{ + { + name: "empty header slice", + headerSlice: []string{}, + expectedResult: map[string]string{}, + }, + { + name: "nil header slice", + headerSlice: nil, + expectedResult: map[string]string{}, + }, + { + name: "odd number of elements", + headerSlice: []string{"header1", "value1", "header2"}, + expectedResult: nil, + }, + { + name: "single key-value pair", + headerSlice: []string{"header1", "value1"}, + expectedResult: map[string]string{ + "header1": "value1", + }, + }, + { + name: "multiple key-value pairs", + headerSlice: []string{"header1", "value1", "header2", "value2", "header3", "value3"}, + expectedResult: map[string]string{ + "header1": "value1", + "header2": "value2", + "header3": "value3", + }, + }, + { + name: "duplicate keys (last value wins)", + headerSlice: []string{"header1", "value1", "header1", "value2"}, + expectedResult: map[string]string{ + "header1": "value2", + }, + }, + { + name: "empty values", + headerSlice: []string{"header1", "", "header2", "value2"}, + expectedResult: map[string]string{ + "header1": "", + "header2": "value2", + }, + }, + { + name: "special characters in keys and values", + headerSlice: []string{"header-1", "value with spaces", "header_2", "value-with-dashes"}, + expectedResult: map[string]string{ + "header-1": "value with spaces", + "header_2": "value-with-dashes", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := context.Background() + result := ContextWithRequestMetadataMapFromHeaderSlice(ctx, tt.headerSlice) + metadataMap := MapFromContext(result) + + if tt.expectedResult == nil { + require.Nil(t, metadataMap) + } else { + require.NotNil(t, metadataMap) + require.Equal(t, len(tt.expectedResult), len(metadataMap)) + for key, expectedValue := range tt.expectedResult { + require.Contains(t, metadataMap, key) + require.Equal(t, expectedValue, metadataMap[key]) + } + } + }) + } +}