Skip to content

Commit

Permalink
refactor: general cleanup of codebase (#65)
Browse files Browse the repository at this point in the history
  • Loading branch information
yezz123 authored Feb 14, 2022
1 parent bba3161 commit bf71495
Show file tree
Hide file tree
Showing 14 changed files with 88 additions and 95 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "xpresso"
version = "0.17.0"
version = "0.17.1"
description = "A developer centric, performant Python web framework"
authors = ["Adrian Garcia Badaracco <[email protected]>"]
readme = "README.md"
Expand Down
5 changes: 1 addition & 4 deletions xpresso/_utils/media_type_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,7 @@ def matches(self, media_type: typing.Optional[str]) -> bool:
if media_type is None:
return False
media_type = next(iter(media_type.split(";"))).lower()
for accepted in self.accepted:
if accepted.match(media_type):
return True
return False
return any(accepted.match(media_type) for accepted in self.accepted)

def validate(
self,
Expand Down
9 changes: 5 additions & 4 deletions xpresso/applications.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ async def __call__(
send: starlette.types.Send,
) -> None:
scope_type = scope["type"]
if scope_type == "http" or scope_type == "websocket":
if scope_type in ["http", "websocket"]:
if not self._setup_run:
self._setup()
extensions = scope.get("extensions", None) or {}
Expand Down Expand Up @@ -249,8 +249,9 @@ def _get_doc_routes(
async def openapi(req: Request) -> JSONResponse:
if self._openapi is None:
self._openapi = self.get_openapi()
res = JSONResponse(self._openapi.dict(exclude_none=True, by_alias=True))
return res
return JSONResponse(
self._openapi.dict(exclude_none=True, by_alias=True)
)

routes.append(
StarletteRoute(
Expand All @@ -266,7 +267,7 @@ async def swagger_ui_html(req: Request) -> HTMLResponse:
full_openapi_url = root_path + openapi_url # type: ignore[operator]
return get_swagger_ui_html(
openapi_url=full_openapi_url,
title=self._openapi_info.title + " - Swagger UI",
title=f"{self._openapi_info.title} - Swagger UI",
oauth2_redirect_url=None,
init_oauth=None,
)
Expand Down
12 changes: 7 additions & 5 deletions xpresso/binders/_body/extractors/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,11 +58,13 @@ def register_parameter(self, param: inspect.Parameter) -> BodyExtractor:

field_marker: typing.Optional[BodyExtractorMarker] = None
for marker in get_markers_from_parameter(param):
if isinstance(marker, BodyBinderMarker):
if marker.extractor_marker is not self:
# the outermost marker must be the field marker (us)
# so the first one that isn't us is the inner marker
field_marker = marker.extractor_marker
if (
isinstance(marker, BodyBinderMarker)
and marker.extractor_marker is not self
):
# the outermost marker must be the field marker (us)
# so the first one that isn't us is the inner marker
field_marker = marker.extractor_marker
if field_marker is None:
raise TypeError(
"No field marker found"
Expand Down
6 changes: 3 additions & 3 deletions xpresso/binders/_body/extractors/form_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,15 +30,15 @@ async def extract_from_form(
) -> typing.Optional[Some[typing.Any]]:
try:
return self.extractor(name=self.name, params=form.multi_items())
except InvalidSerialization:
except InvalidSerialization as e:
raise RequestValidationError(
[
ErrorWrapper(
exc=TypeError("Data is not a valid URL encoded form"),
loc=tuple((*loc, self.name)),
)
]
)
) from e
except UnexpectedFileReceived as exc:
raise RequestValidationError(
[
Expand All @@ -47,7 +47,7 @@ async def extract_from_form(
loc=tuple((*loc, self.name)),
)
]
)
) from exc


@dataclass(frozen=True)
Expand Down
4 changes: 2 additions & 2 deletions xpresso/binders/_body/extractors/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,15 +79,15 @@ def _decode(
) -> typing.Union[bytes, UploadFile]:
try:
decoded = self.decoder(value)
except Exception:
except Exception as e:
raise RequestValidationError(
[
ErrorWrapper(
exc=TypeError("Data is not valid JSON"),
loc=tuple(loc),
)
]
)
) from e
return decoded


Expand Down
7 changes: 2 additions & 5 deletions xpresso/binders/_body/openapi/discriminated.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,11 +49,7 @@ class OpenAPIContentTypeDiscriminatedMarker(OpenAPIBodyMarker):

def register_parameter(self, param: inspect.Parameter) -> OpenAPIBody:
field = model_field_from_param(param)
if field.required is False:
required = False
else:
required = True

required = field.required is not False
sub_body_providers: typing.Dict[str, OpenAPIBody] = {}

annotation = param.annotation
Expand All @@ -76,6 +72,7 @@ def register_parameter(self, param: inspect.Parameter) -> OpenAPIBody:
if isinstance(param_marker, BodyBinderMarker):
marker = param_marker
break

if marker is None:
raise TypeError(f"Type annotation is missing body marker: {arg}")
sub_body_openapi = marker.openapi_marker
Expand Down
12 changes: 7 additions & 5 deletions xpresso/binders/_body/openapi/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,11 +58,13 @@ class OpenAPIFieldMarkerBase(OpenAPIBodyMarker):
def register_parameter(self, param: inspect.Parameter) -> OpenAPIBody:
field_marker: typing.Optional[OpenAPIBodyMarker] = None
for marker in get_markers_from_parameter(param):
if isinstance(marker, BodyBinderMarker):
if marker.openapi_marker is not self:
# the outermost marker must be the field marker (us)
# so the first one that isn't us is the inner marker
field_marker = marker.openapi_marker
if (
isinstance(marker, BodyBinderMarker)
and marker.openapi_marker is not self
):
# the outermost marker must be the field marker (us)
# so the first one that isn't us is the inner marker
field_marker = marker.openapi_marker
if field_marker is None:
raise TypeError(
"No field marker found"
Expand Down
5 changes: 1 addition & 4 deletions xpresso/binders/_body/openapi/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,10 +77,7 @@ class OpenAPIFileMarker(OpenAPIBodyMarker):
def register_parameter(self, param: inspect.Parameter) -> OpenAPIBody:
field = model_field_from_param(param)
examples = parse_examples(self.examples) if self.examples else None
if field.required is False:
required = False
else:
required = True
required = field.required is not False
return OpenAPIFileBody(
media_type=self.media_type,
description=self.description,
Expand Down
7 changes: 2 additions & 5 deletions xpresso/binders/_body/openapi/form.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,11 +100,7 @@ class OpenAPIFormDataMarker(OpenAPIBodyMarker):

def register_parameter(self, param: inspect.Parameter) -> OpenAPIBody:
form_data_field = model_field_from_param(param)
if form_data_field.required is False:
required = False
else:
required = True

required = form_data_field.required is not False
field_openapi_providers: typing.Dict[str, OpenAPIBody] = {}
required_fields: typing.List[str] = []
# use pydantic to get rid of outer annotated, optional, etc.
Expand All @@ -115,6 +111,7 @@ def register_parameter(self, param: inspect.Parameter) -> OpenAPIBody:
if isinstance(param_marker, BodyBinderMarker):
marker = param_marker
break

field_openapi: OpenAPIBodyMarker
if marker is None:
# use the defaults
Expand Down
5 changes: 1 addition & 4 deletions xpresso/binders/_body/openapi/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,7 @@ class OpenAPIJsonMarker(OpenAPIBodyMarker):
def register_parameter(self, param: inspect.Parameter) -> OpenAPIBody:
examples = parse_examples(self.examples) if self.examples else None
field = model_field_from_param(param)
if field.required is False:
required = False
else:
required = True
required = field.required is not False
return OpenAPIJsonBody(
description=self.description,
examples=examples,
Expand Down
14 changes: 8 additions & 6 deletions xpresso/encoders/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,10 @@ def apply_custom_encoder(
return None

def __call__(
self, obj: Any, custom_encoder: Dict[Any, Callable[[Any], Any]] = {}
self, obj: Any, custom_encoder: Optional[Dict[Any, Callable[[Any], Any]]] = None
) -> Any:
if custom_encoder is None:
custom_encoder = {}
custom_encoder = {**self.custom_encoder, **custom_encoder}
if isinstance(obj, BaseModel):
encoder = getattr(obj.__config__, "json_encoders", {})
Expand Down Expand Up @@ -113,10 +115,10 @@ def __call__(
encoded_dict[encoded_key] = encoded_value
return encoded_dict
if isinstance(obj, (list, set, frozenset, GeneratorType, tuple)):
encoded_list: List[Any] = []
for item in cast(Sequence[Any], obj):
encoded_list.append(self(item, custom_encoder=custom_encoder))
return encoded_list
return [
self(item, custom_encoder=custom_encoder)
for item in cast(Sequence[Any], obj)
]

custom = self.apply_custom_encoder(obj, custom_encoder=custom_encoder)
if isinstance(custom, Some):
Expand All @@ -137,5 +139,5 @@ def __call__(
data = vars(obj)
except Exception as e:
errors.append(e)
raise ValueError(errors)
raise ValueError(errors) from e
return self(data, custom_encoder=custom_encoder)
45 changes: 23 additions & 22 deletions xpresso/openapi/_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
"detail": {
"title": "Detail",
"type": "array",
"items": {"$ref": REF_PREFIX + "ValidationError"},
"items": {"$ref": f"{REF_PREFIX}ValidationError"},
}
},
}
Expand All @@ -56,14 +56,12 @@ def get_parameters(
model_name_map: ModelNameMap,
schemas: Dict[str, Any],
) -> Optional[List[models.ConcreteParameter]]:
parameters: List[models.ConcreteParameter] = []
for dependant in deps:
if dependant.openapi and dependant.openapi.include_in_schema:
parameters.append(
dependant.openapi.get_openapi(
model_name_map=model_name_map, schemas=schemas
)
)
parameters: List[models.ConcreteParameter] = [
dependant.openapi.get_openapi(model_name_map=model_name_map, schemas=schemas)
for dependant in deps
if dependant.openapi and dependant.openapi.include_in_schema
]

if parameters:
return list(sorted(parameters, key=lambda param: param.name))
return None
Expand Down Expand Up @@ -179,17 +177,18 @@ def get_operation(
if schemas:
components["schemas"] = {**components.get("schemas", {}), **schemas}
http422 = str(HTTP_422_UNPROCESSABLE_ENTITY)
if (data.get("parameters", None) or data.get("requestBody", None)) and not any(
status in data["responses"] for status in (http422, "4XX", "default")
if ((data.get("parameters", None) or data.get("requestBody", None))) and all(
status not in data["responses"] for status in (http422, "4XX", "default")
):
data["responses"][http422] = {
"description": "Validation Error",
"content": {
"application/json": {
"schema": {"$ref": REF_PREFIX + "HTTPValidationError"}
"schema": {"$ref": f"{REF_PREFIX}HTTPValidationError"}
}
},
}

if "ValidationError" not in schemas:
components["schemas"] = components.get("schemas", None) or {}
components["schemas"].update(
Expand Down Expand Up @@ -226,17 +225,18 @@ def get_paths_items(
continue
tags.extend(path_item.tags)
responses.update(path_item.responses)
operations: Dict[str, models.Operation] = {}
for method, operation in path_item.operations.items():
if not operation.include_in_schema:
continue
operations[method.lower()] = get_operation(
operations: Dict[str, models.Operation] = {
method.lower(): get_operation(
operation,
model_name_map=model_name_map,
components=components,
tags=tags + operation.tags,
response_specs={**responses, **operation.responses},
)
for method, operation in path_item.operations.items()
if operation.include_in_schema
}

paths[visited_route.path] = models.PathItem(
description=visited_route.route.description,
summary=visited_route.route.summary,
Expand All @@ -253,11 +253,12 @@ def filter_routes(visitor: Iterable[VisitedRoute[Any]]) -> Routes:
path_item = visited_route.route
if not path_item.include_in_schema:
continue
operations: Dict[str, Operation] = {}
for method, operation in path_item.operations.items():
if not operation.include_in_schema:
continue
operations[method.lower()] = operation
operations: Dict[str, Operation] = {
method.lower(): operation
for method, operation in path_item.operations.items()
if operation.include_in_schema
}

res[visited_route.path] = (path_item, operations)
return res

Expand Down
Loading

0 comments on commit bf71495

Please sign in to comment.