Skip to content
60 changes: 58 additions & 2 deletions cssselect/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,12 +190,33 @@ class Function:
Represents selector:name(expr)
"""

def __init__(self, selector: Tree, name: str, arguments: Sequence["Token"]) -> None:
def __init__(
self,
selector: Tree,
name: str,
arguments: Sequence["Token"],
of_type: Optional[List[Selector]] = None,
) -> None:
self.selector = selector
self.name = ascii_lower(name)
self.arguments = arguments

# for css4 :nth-child(An+B of Subselector)
self.of_type: Optional[Selector]
if of_type:
self.of_type = of_type[0]
else:
self.of_type = None

def __repr__(self) -> str:
if self.of_type:
return "%s[%r:%s(%r of %s)]" % (
self.__class__.__name__,
self.selector,
self.name,
[token.value for token in self.arguments],
self.of_type.__repr__(),
)
return "%s[%r:%s(%r)]" % (
self.__class__.__name__,
self.selector,
Expand Down Expand Up @@ -695,7 +716,8 @@ def parse_simple_selector(
selectors = parse_simple_selector_arguments(stream)
result = SpecificityAdjustment(result, selectors)
else:
result = Function(result, ident, parse_arguments(stream))
fn_arguments, of_type = parse_function_arguments(stream)
result = Function(result, ident, fn_arguments, of_type)
else:
raise SelectorSyntaxError("Expected selector, got %s" % (peek,))
if len(stream.used) == selector_start:
Expand All @@ -716,6 +738,29 @@ def parse_arguments(stream: "TokenStream") -> List["Token"]:
raise SelectorSyntaxError("Expected an argument, got %s" % (next,))


def parse_function_arguments(
stream: "TokenStream",
) -> Tuple[List["Token"], Optional[List[Selector]]]:
arguments: List["Token"] = []
while 1:
stream.skip_whitespace()
next = stream.next()
if next == ("IDENT", "of"):
stream.skip_whitespace()
of_type = parse_of_type(stream)
return arguments, of_type
elif next.type in ("IDENT", "STRING", "NUMBER") or next in [
("DELIM", "+"),
("DELIM", "-"),
]:
arguments.append(next)
elif next == ("DELIM", ")"):
return arguments, None

else:
raise SelectorSyntaxError("Expected an argument, got %s" % (next,))


def parse_relative_selector(stream: "TokenStream") -> Tuple["Token", Selector]:
stream.skip_whitespace()
subselector = ""
Expand Down Expand Up @@ -761,6 +806,17 @@ def parse_simple_selector_arguments(stream: "TokenStream") -> List[Tree]:
return arguments


def parse_of_type(stream: "TokenStream") -> List[Selector]:
subselector = ""
while 1:
next = stream.next()
if next == ("DELIM", ")"):
break
subselector += typing.cast(str, next.value)
result = parse(subselector)
return result


def parse_attrib(selector: Tree, stream: "TokenStream") -> Attrib:
stream.skip_whitespace()
attrib = stream.next_ident_or_star()
Expand Down
4 changes: 3 additions & 1 deletion cssselect/xpath.py
Original file line number Diff line number Diff line change
Expand Up @@ -515,7 +515,9 @@ def xpath_nth_child_function(
# `add_name_test` boolean is inverted and somewhat counter-intuitive:
#
# nth_of_type() calls nth_child(add_name_test=False)
if add_name_test:
if function.of_type:
nodetest = str(self.xpath(function.of_type.parsed_tree))
elif add_name_test:
nodetest = "*"
else:
nodetest = "%s" % xpath.element
Expand Down
1 change: 1 addition & 0 deletions pylintrc
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ disable=assignment-from-no-return,
too-many-branches,
too-many-function-args,
too-many-lines,
too-many-locals,
too-many-public-methods,
too-many-statements,
undefined-variable,
Expand Down
11 changes: 11 additions & 0 deletions tests/test_cssselect.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,6 +428,14 @@ def xpath(css: str) -> str:
)

# --- nth-* and nth-last-* -------------------------------------
assert xpath("e:nth-child(2n+1 of S)") == "e[count(preceding-sibling::S) mod 2 = 0]"
assert xpath("e:nth-of-type(2n+1 of S)") == "e[count(preceding-sibling::S) mod 2 = 0]"
assert (
xpath("e:nth-child(2n+1 of li.important)") == "e[count(preceding-sibling::li[@class"
" and contains(concat(' ', normalize-space(@class), ' '), ' important ')])"
" mod 2 = 0]"
)

assert xpath("e:nth-child(1)") == ("e[count(preceding-sibling::*) = 0]")

# always true
Expand Down Expand Up @@ -503,6 +511,9 @@ def xpath(css: str) -> str:
assert xpath("e ~ f:nth-child(3)") == (
"e/following-sibling::f[count(preceding-sibling::*) = 2]"
)
assert xpath("e ~ f:nth-child(3 of S)") == (
"e/following-sibling::f[count(preceding-sibling::S) = 2]"
)
assert xpath("div#container p") == ("div[@id = 'container']/descendant-or-self::*/p")
assert xpath("e:where(foo)") == "e[name() = 'foo']"
assert xpath("e:where(foo, bar)") == "e[(name() = 'foo') or (name() = 'bar')]"
Expand Down