Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Security classes #68

Closed
wants to merge 11 commits into from
77 changes: 77 additions & 0 deletions tests/test_security/test_security_key_header.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
from typing import Any, Dict, Generator

import pytest
from pydantic import BaseModel

from xpresso import App, Depends, Path
from xpresso.security import APIKeyHeader
from xpresso.testclient import TestClient
from xpresso.typing import Annotated


class APIKey(APIKeyHeader):
name = "key"


class User(BaseModel):
username: str


def get_current_user(key: APIKey):
user = User(username=key.api_key)
return user


def read_current_user(current_user: Annotated[User, Depends(get_current_user)]):
return current_user


app = App([Path("/users/me", get=read_current_user)])


@pytest.fixture
def client() -> Generator[TestClient, None, None]:
with TestClient(app) as client:
yield client


openapi_schema: Dict[str, Any] = {
"openapi": "3.0.3",
"info": {"title": "API", "version": "0.1.0"},
"paths": {
"/users/me": {
"get": {
"responses": {
"200": {
"description": "Successful Response",
}
},
"security": [{"APIKeyHeader": []}],
}
}
},
"components": {
"securitySchemes": {
"APIKeyHeader": {"type": "apiKey", "name": "key", "in": "header"}
}
},
}


@pytest.mark.xfail
def test_openapi_schema(client: TestClient):
response = client.get("/openapi.json")
assert response.status_code == 200, response.text
assert response.json() == openapi_schema


def test_security_api_key(client: TestClient):
response = client.get("/users/me", headers={"key": "secret"})
assert response.status_code == 200, response.text
assert response.json() == {"username": "secret"}


def test_security_api_key_no_key(client: TestClient):
response = client.get("/users/me")
assert response.status_code == 401, response.text
assert response.json() == {"detail": "Not authenticated"}
4 changes: 3 additions & 1 deletion xpresso/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
from xpresso.routing.pathitem import Path
from xpresso.routing.router import Router
from xpresso.routing.websockets import WebSocketRoute
from xpresso.security import SecurityModel
from xpresso.websockets import WebSocket

__all__ = (
Expand Down Expand Up @@ -69,9 +70,10 @@
"RepeatedFormField",
"FormField",
"FromFile",
"status",
"Request",
"Response",
"status",
"SecurityModel",
"ByContentType",
"WebSocketRoute",
"WebSocket",
Expand Down
Empty file.
84 changes: 84 additions & 0 deletions xpresso/binders/_security/apikey.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
from __future__ import annotations

from typing import ClassVar, Optional

from starlette.requests import HTTPConnection
from starlette.status import HTTP_401_UNAUTHORIZED

import xpresso.openapi.models as openapi_models
from xpresso.binders.api import NamedSecurityScheme, SecurityScheme
from xpresso.exceptions import HTTPException

UNAUTHORIZED_EXC = HTTPException(
status_code=HTTP_401_UNAUTHORIZED, detail="Not authenticated"
)


class APIKeyBase(SecurityScheme):
api_key: str
name: ClassVar[str]
scheme_name: ClassVar[Optional[str]] = None
description: ClassVar[Optional[str]] = None
unauthorized_error: ClassVar[Optional[Exception]] = UNAUTHORIZED_EXC
in_: ClassVar[str]

__slots__ = ("api_key",)

def __init__(self, api_key: str) -> None:
self.api_key = api_key

@classmethod
def get_openapi(cls) -> NamedSecurityScheme:
scheme = openapi_models.APIKey.parse_obj(
{
"in": cls.in_,
"description": cls.description,
"name": cls.name,
}
)
return NamedSecurityScheme(
name=cls.scheme_name or cls.__name__,
scheme=scheme,
)


class APIKeyQuery(APIKeyBase):
in_ = "query"

@classmethod
async def extract(cls, conn: HTTPConnection) -> Optional[APIKeyQuery]:
api_key: Optional[str] = conn.query_params.get(cls.name)
if not api_key:
if cls.unauthorized_error:
raise cls.unauthorized_error
else:
return None
return APIKeyQuery(api_key=api_key)


class APIKeyHeader(APIKeyBase):
in_ = "header"

@classmethod
async def extract(cls, conn: HTTPConnection) -> Optional[APIKeyHeader]:
api_key: Optional[str] = conn.headers.get(cls.name)
if not api_key:
if cls.unauthorized_error:
raise cls.unauthorized_error
else:
return None
return APIKeyHeader(api_key=api_key)


class APIKeyCookie(APIKeyBase):
in_ = "cookie"

@classmethod
async def extract(cls, conn: HTTPConnection) -> Optional[APIKeyCookie]:
api_key: Optional[str] = conn.cookies.get(cls.name)
if not api_key:
if cls.unauthorized_error:
raise cls.unauthorized_error
else:
return None
return APIKeyCookie(api_key=api_key)
168 changes: 168 additions & 0 deletions xpresso/binders/_security/oauth.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
from __future__ import annotations

import sys
from typing import AbstractSet, ClassVar, List, Mapping, Optional

if sys.version_info < (3, 8):
from typing_extensions import Literal
else:
from typing import Literal

from pydantic import BaseModel
from starlette.requests import HTTPConnection
from starlette.status import HTTP_401_UNAUTHORIZED

import xpresso.openapi.models as openapi_models
from xpresso.binders._security.utils import get_authorization_scheme_param
from xpresso.binders.api import NamedSecurityScheme, SecurityScheme
from xpresso.bodies import FormEncodedField
from xpresso.exceptions import HTTPException
from xpresso.typing import Annotated


class OAuth2PasswordRequestForm(BaseModel):
"""
This is a dependency class, use it like:
@app.post("/login")
def login(form_data: Annotated[OAuth2PasswordRequestForm, Depends()]):
data = form_data.parse()
print(data.username)
print(data.password)
for scope in data.scopes:
print(scope)
if data.client_id:
print(data.client_id)
if data.client_secret:
print(data.client_secret)
return data
It creates the following Form conn parameters in your endpoint:
grant_type: the OAuth2 spec says it is required and MUST be the fixed string "password".
username: username string. The OAuth2 spec requires the exact field name "username".
password: password string. The OAuth2 spec requires the exact field name "password".
scope: Optional string. Several scopes (each one a string) separated by spaces. E.g.
"items:read items:write users:read profile openid"
client_id: optional string. OAuth2 recommends sending the client_id and client_secret (if any)
using HTTP Basic auth, as: client_id:client_secret
client_secret: optional string. OAuth2 recommends sending the client_id and client_secret (if any)
using HTTP Basic auth, as: client_id:client_secret
"""

username: str
password: str
scopes: Annotated[
List[str], FormEncodedField(style="spaceDelimited", explode=False)
]
client_id: Optional[str] = None
client_secret: Optional[str] = None
grant_type: Literal["password"]


class OAuth2(SecurityScheme):
token: str
description: ClassVar[Optional[str]] = None
scheme_name: ClassVar[Optional[str]] = None
unauthenticated_error: ClassVar[Optional[Exception]] = HTTPException(
status_code=HTTP_401_UNAUTHORIZED, detail="Not authenticated"
)

def __init__(self, token: str) -> None:
self.token = token

@classmethod
async def extract(cls, conn: HTTPConnection) -> Optional[OAuth2]: # type: ignore # for Pylance
authorization: Optional[str] = conn.headers.get("Authorization")
if not authorization:
if cls.unauthenticated_error:
raise cls.unauthenticated_error
else:
return None
return OAuth2(token=authorization)

@classmethod
def get_openapi(cls) -> NamedSecurityScheme:
scheme = openapi_models.OAuth2(
description=cls.description, flows=openapi_models.OAuthFlows()
)
return NamedSecurityScheme(
name=cls.scheme_name or cls.__name__,
scheme=scheme,
)


class OAuth2PasswordBearer(OAuth2):
token_url: str
scopes: Optional[Mapping[str, str]] = None
required_scopes: AbstractSet[str] = frozenset()
unauthenticated_error = HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail="Not authenticated",
headers={"WWW-Authenticate": "Bearer"},
)

@classmethod
def get_openapi(cls) -> NamedSecurityScheme:
scheme = openapi_models.OAuth2(
flows=openapi_models.OAuthFlows(
password=openapi_models.OAuthFlowPassword(
scopes=cls.scopes, tokenUrl=cls.token_url
)
)
)
return NamedSecurityScheme(
name=cls.scheme_name or cls.__name__,
scheme=scheme,
)

@classmethod
async def extract(cls, conn: HTTPConnection) -> Optional[OAuth2PasswordBearer]:
authorization: Optional[str] = conn.headers.get("Authorization")
if authorization:
scheme, param = get_authorization_scheme_param(authorization)
if scheme.lower == "bearer":
return OAuth2PasswordBearer(token=param)
if cls.unauthenticated_error:
raise cls.unauthenticated_error
return None


class OAuth2AuthorizationCodeBearer(OAuth2):
token_url: ClassVar[str]
authorization_url: ClassVar[str]
refresh_url: ClassVar[Optional[str]] = None
scopes: ClassVar[Optional[Mapping[str, str]]] = None
required_scopes: ClassVar[Optional[AbstractSet[str]]] = None
unauthenticated_error: ClassVar[Exception] = HTTPException(
status_code=HTTP_401_UNAUTHORIZED,
detail="Not authenticated",
headers={"WWW-Authenticate": "Bearer"},
)

@classmethod
def get_openapi(cls) -> NamedSecurityScheme:
scheme = openapi_models.OAuth2(
flows=openapi_models.OAuthFlows(
authorizationCode=openapi_models.OAuthFlowAuthorizationCode(
refreshUrl=cls.refresh_url, # type: ignore[arg-type]
scopes=cls.scopes or {},
authorizationUrl=cls.authorization_url,
tokenUrl=cls.token_url,
)
)
)
return NamedSecurityScheme(
name=cls.scheme_name or cls.__name__,
scheme=scheme,
)

@classmethod
async def extract(
cls, conn: HTTPConnection
) -> Optional[OAuth2AuthorizationCodeBearer]:
authorization: Optional[str] = conn.headers.get("Authorization")
if authorization:
scheme, param = get_authorization_scheme_param(authorization)
if scheme.lower == "bearer":
return OAuth2AuthorizationCodeBearer(token=param)
if cls.unauthenticated_error:
raise cls.unauthenticated_error
return None
8 changes: 8 additions & 0 deletions xpresso/binders/_security/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from typing import Tuple


def get_authorization_scheme_param(authorization_header_value: str) -> Tuple[str, str]:
if not authorization_header_value:
return "", ""
scheme, _, param = authorization_header_value.partition(" ")
return scheme, param
Loading