Skip to content

Commit 2fbedf9

Browse files
committed
oauth router
Signed-off-by: Veeresh K <[email protected]>
1 parent f33da1b commit 2fbedf9

File tree

3 files changed

+32
-12
lines changed

3 files changed

+32
-12
lines changed

mcpgateway/routers/oauth_router.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,6 @@ async def initiate_oauth_flow(gateway_id: str, _request: Request, db: Session =
8181
async def oauth_callback(
8282
code: str = Query(..., description="Authorization code from OAuth provider"),
8383
state: str = Query(..., description="State parameter for CSRF protection"),
84-
_request: Request = None,
8584
db: Session = Depends(get_db),
8685
) -> HTMLResponse:
8786
"""Handle the OAuth callback and complete the authorization process.
@@ -93,7 +92,6 @@ async def oauth_callback(
9392
Args:
9493
code (str): The authorization code returned by the OAuth provider.
9594
state (str): The state parameter for CSRF protection, which encodes the gateway ID.
96-
_request (Request): The incoming HTTP request object.
9795
db (Session): The database session dependency.
9896
9997
Returns:

mcpgateway/utils/url_utils.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,9 @@
1+
"""URL utilities for MCP Gateway.
2+
3+
Provides functions for handling URL protocol detection and manipulation,
4+
especially for proxy environments with forwarded headers.
5+
"""
6+
17
# Standard
28
from urllib.parse import urlparse, urlunparse
39

@@ -6,16 +12,15 @@
612

713

814
def get_protocol_from_request(request: Request) -> str:
9-
"""
10-
Return "https" or "http" based on:
11-
1) X-Forwarded-Proto (if set by a proxy)
12-
2) request.url.scheme (e.g. when Gunicorn/Uvicorn is terminating TLS)
15+
"""Get protocol from request headers or URL scheme.
16+
17+
Checks X-Forwarded-Proto header first, then falls back to request.url.scheme.
1318
1419
Args:
15-
request (Request): The FastAPI request object.
20+
request: The FastAPI request object
1621
1722
Returns:
18-
str: The protocol used for the request, either "http" or "https".
23+
Protocol string: "http" or "https"
1924
"""
2025
forwarded = request.headers.get("x-forwarded-proto")
2126
if forwarded:
@@ -26,14 +31,13 @@ def get_protocol_from_request(request: Request) -> str:
2631

2732

2833
def update_url_protocol(request: Request) -> str:
29-
"""
30-
Update the base URL protocol based on the request's scheme or forwarded headers.
34+
"""Update base URL protocol based on request headers.
3135
3236
Args:
33-
request (Request): The FastAPI request object.
37+
request: The FastAPI request object
3438
3539
Returns:
36-
str: The base URL with the correct protocol.
40+
Base URL with correct protocol
3741
"""
3842
parsed = urlparse(str(request.base_url))
3943
proto = get_protocol_from_request(request)
Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
import pytest
2+
from unittest.mock import Mock
3+
from mcpgateway.utils.url_utils import get_protocol_from_request
4+
5+
6+
@pytest.mark.parametrize("headers, expected",
7+
[({"x-forwarded-proto": "http"}, "http"), # case with header
8+
({}, "https"), # fallback to request.url.scheme
9+
],
10+
)
11+
def test_get_protocol_from_request(headers, expected):
12+
"""Test get_protocol_from_request with and without x-forwarded-proto header."""
13+
mock_request = Mock()
14+
mock_request.headers = headers
15+
mock_request.url.scheme = "https"
16+
17+
result = get_protocol_from_request(mock_request)
18+
assert result == expected

0 commit comments

Comments
 (0)