Skip to content

Commit

Permalink
Better error handling for invalid children.
Browse files Browse the repository at this point in the history
Validate children directly when possible.
  • Loading branch information
pelme committed Sep 12, 2024
1 parent e61d2db commit 909ecb8
Show file tree
Hide file tree
Showing 3 changed files with 107 additions and 6 deletions.
4 changes: 4 additions & 0 deletions docs/changelog.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# Changelog

## next
- Raise errors directly on invalid children. This avoids cryptic stack traces.
[PR #56](https://github.com/pelme/htpy/pull/56).

## 24.9.1 - 2024-09-09
- Raise errors directly on invalid attributes. This avoids cryptic stack traces
for invalid attributes. [Issue #49](https://github.com/pelme/htpy/issues/49)
Expand Down
37 changes: 34 additions & 3 deletions htpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,14 @@
import dataclasses
import functools
import typing as t
from collections.abc import Callable, Iterable, Iterator
from collections.abc import Callable, Generator, Iterable, Iterator

from markupsafe import Markup as _Markup
from markupsafe import escape as _escape

if t.TYPE_CHECKING:
from types import UnionType

BaseElementSelf = t.TypeVar("BaseElementSelf", bound="BaseElement")
ElementSelf = t.TypeVar("ElementSelf", bound="Element")

Expand Down Expand Up @@ -190,7 +193,7 @@ def _iter_node_context(x: Node, context_dict: dict[Context[t.Any], t.Any]) -> It
yield str(_escape(x))
elif isinstance(x, int):
yield str(x)
elif isinstance(x, Iterable) and not isinstance(x, bytes): # pyright: ignore [reportUnnecessaryIsInstance]
elif isinstance(x, Iterable) and not isinstance(x, _KnownInvalidChildren): # pyright: ignore [reportUnnecessaryIsInstance]
for child in x:
yield from _iter_node_context(child, context_dict)
else:
Expand Down Expand Up @@ -288,9 +291,22 @@ def encode(self, encoding: str = "utf-8", errors: str = "strict") -> bytes:
do_not_call_in_templates = True


def _validate_children(children: t.Any) -> None:
if isinstance(children, _KnownValidChildren):
return

if isinstance(children, Iterable) and not isinstance(children, _KnownInvalidChildren):
for child in children: # pyright: ignore [reportUnknownVariableType]
_validate_children(child)
return

raise ValueError(f"{children!r} is not a valid child element")


class Element(BaseElement):
def __getitem__(self: ElementSelf, children: Node) -> ElementSelf:
return self.__class__(self._name, self._attrs, children)
_validate_children(children)
return self.__class__(self._name, self._attrs, children) # pyright: ignore [reportUnknownArgumentType]


class HTMLElement(Element):
Expand Down Expand Up @@ -457,3 +473,18 @@ def __html__(self) -> str: ...
u = Element("u")
ul = Element("ul")
var = Element("var")


_KnownInvalidChildren: UnionType = bytes | bytearray | memoryview

_KnownValidChildren: UnionType = ( # pyright: ignore [reportUnknownVariableType]
None
| BaseElement
| ContextProvider # pyright: ignore [reportMissingTypeArgument]
| ContextConsumer # pyright: ignore [reportMissingTypeArgument]
| Callable # pyright: ignore [reportMissingTypeArgument]
| str
| int
| Generator # pyright: ignore [reportMissingTypeArgument]
| _HasHtml
)
72 changes: 69 additions & 3 deletions tests/test_children.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
from __future__ import annotations

import dataclasses
import datetime
import decimal
import pathlib
import re
import typing as t

import pytest
Expand Down Expand Up @@ -202,7 +207,68 @@ def test_callable_in_generator() -> None:
assert str(div[((lambda: "hi") for _ in range(1))]) == "<div>hi</div>"


@pytest.mark.parametrize("not_a_child", [12.34, b"foo", object(), object])
def test_invalid_child(not_a_child: t.Any) -> None:
@dataclasses.dataclass
class MyDataClass:
name: str


class SomeClass:
pass


# Various types that are not valid children.
_invalid_children = [
12.34,
decimal.Decimal("12.34"),
complex("+1.23"),
object(),
datetime.date(1, 2, 3),
datetime.datetime(1, 2, 3),
datetime.time(1, 2),
b"foo",
bytearray(b"foo"),
memoryview(b"foo"),
Exception("foo"),
Ellipsis,
re.compile("foo"),
pathlib.Path("FOO"),
re, # module type
MyDataClass(name="Andreas"),
SomeClass(),
]


@pytest.mark.parametrize("not_a_child", _invalid_children)
def test_invalid_child_direct(not_a_child: t.Any) -> None:
with pytest.raises(ValueError, match="is not a valid child element"):
div[not_a_child]


@pytest.mark.parametrize("not_a_child", _invalid_children)
def test_invalid_child_nested_iterable(not_a_child: t.Any) -> None:
with pytest.raises(ValueError, match="is not a valid child element"):
div[[not_a_child]]


@pytest.mark.parametrize("not_a_child", _invalid_children)
def test_invalid_child_lazy_callable(not_a_child: t.Any) -> None:
"""
Ensure proper exception is raised for lazily evaluated invalid children.
"""
element = div[lambda: not_a_child]
with pytest.raises(ValueError, match="is not a valid child element"):
str(element)


@pytest.mark.parametrize("not_a_child", _invalid_children)
def test_invalid_child_lazy_generator(not_a_child: t.Any) -> None:
"""
Ensure proper exception is raised for lazily evaluated invalid children.
"""

def gen() -> t.Any:
yield not_a_child

element = div[gen()]
with pytest.raises(ValueError, match="is not a valid child element"):
str(div[not_a_child])
str(element)

0 comments on commit 909ecb8

Please sign in to comment.