Skip to content

Commit a6fd365

Browse files
committed
💥 MultiVar -> Field.multiple
1 parent ae3c818 commit a6fd365

File tree

7 files changed

+73
-109
lines changed

7 files changed

+73
-109
lines changed

src/arclet/alconna/__init__.py

-3
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,6 @@
3838
from .manager import ShortcutArgs as ShortcutArgs
3939
from .manager import command_manager as command_manager
4040
from .typing import AllParam as AllParam
41-
from .typing import MultiVar as MultiVar
42-
from .typing import Nargs as Nargs
43-
from .typing import StrMulti as StrMulti
4441

4542
__version__ = "1.8.31"
4643

src/arclet/alconna/args.py

+28-13
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,15 @@
44
import dataclasses as dc
55
import re
66
import typing
7-
from typing import Any, Callable, Generic, TypeVar, ClassVar, ForwardRef, Final, TYPE_CHECKING
7+
from typing import Any, Callable, Generic, Literal, TypeVar, ClassVar, ForwardRef, Final, TYPE_CHECKING, get_origin, get_args
88
from typing_extensions import dataclass_transform, ParamSpec, Concatenate, TypeAlias
99

10-
from nepattern import NONE, BasePattern, RawStr, UnionPattern, parser
10+
from nepattern import NONE, BasePattern, RawStr, UnionPattern, parser, STRING
1111
from tarina import Empty, lang
1212

1313
from ._dcls import safe_dcls_kw, safe_field_kw
1414
from .exceptions import InvalidArgs
15-
from .typing import MultiVar, TAValue, parent_frame_namespace, merge_cls_and_parent_ns
15+
from .typing import TAValue, parent_frame_namespace, merge_cls_and_parent_ns
1616

1717
_T = TypeVar("_T")
1818

@@ -40,6 +40,8 @@ class Field(Generic[_T]):
4040
optional: bool = dc.field(default=False, compare=False, hash=False)
4141
hidden: bool = dc.field(default=False, compare=False, hash=False)
4242
kw_only: bool = dc.field(default=False, compare=False, hash=False)
43+
multiple: bool | int | Literal["+", "*", "str"] = dc.field(default=False, compare=False, hash=False)
44+
kw_sep: str = dc.field(default="=", compare=False, hash=False)
4345

4446
@property
4547
def display(self):
@@ -87,11 +89,13 @@ def arg_field(
8789
missing_tips: Callable[[], str] | None = None,
8890
notice: str | None = None,
8991
seps: str = " ",
92+
multiple: bool | int | Literal["+", "*", "str"] = False,
9093
kw_only: bool = False,
94+
kw_sep: str = "=",
9195
optional: bool = False,
9296
hidden: bool = False,
9397
) -> "Any":
94-
return Field(default, default_factory, alias, completion, unmatch_tips, missing_tips, notice, seps, optional, hidden, kw_only)
98+
return Field(default, default_factory, alias, completion, unmatch_tips, missing_tips, notice, seps, optional, hidden, kw_only, multiple, kw_sep)
9599

96100

97101
@dc.dataclass(**safe_dcls_kw(init=False, eq=True, unsafe_hash=True, slots=True))
@@ -167,8 +171,8 @@ def __init__(self, args: list[Arg[Any]], origin: type[ArgsBase] | None = None):
167171
self.data = args
168172
self.normal: list[Arg[Any]] = []
169173
self.keyword_only: dict[str, Arg[Any]] = {}
170-
self.vars_positional: list[tuple[MultiVar, Arg[Any]]] = []
171-
self.vars_keyword: list[tuple[MultiVar, Arg[Any]]] = []
174+
self.vars_positional: list[tuple[int | Literal["+", "*", "str"], Arg[Any]]] = []
175+
self.vars_keyword: list[tuple[str, Arg[Any]]] = []
172176
self._visit = set()
173177
self.optional_count = 0
174178
self.__check_vars__()
@@ -189,17 +193,20 @@ def __check_vars__(self):
189193
if arg.name in self._visit:
190194
continue
191195
self._visit.add(arg.name)
192-
if isinstance(arg.type_, MultiVar):
196+
if arg.field.multiple is not False:
193197
if arg.field.kw_only:
194-
# for slot in self.vars_positional:
195-
# _, a = slot
196-
# if arg.type_.base.sep in a.field.seps:
197-
# raise InvalidArgs("varkey cannot use the same sep as varpos's Arg")
198-
self.vars_keyword.append((arg.type_, arg))
198+
for slot in self.vars_positional:
199+
_, a = slot
200+
if arg.field.kw_sep in a.field.seps:
201+
raise InvalidArgs("varkey cannot use the same sep as varpos's Arg")
202+
self.vars_keyword.append((arg.field.kw_sep, arg))
199203
elif self.keyword_only:
200204
raise InvalidArgs(lang.require("args", "exclude_mutable_args"))
201205
else:
202-
self.vars_positional.append((arg.type_, arg))
206+
flag = arg.field.multiple
207+
if flag is True:
208+
flag = "+"
209+
self.vars_positional.append((flag, arg))
203210
elif arg.field.kw_only:
204211
if self.vars_keyword:
205212
raise InvalidArgs(lang.require("args", "exclude_mutable_args"))
@@ -328,6 +335,14 @@ def __new__(
328335
field = Field(field)
329336
if field.default is Empty and field.default_factory is Empty:
330337
delattr(cls, name)
338+
if field.multiple is not False:
339+
if not field.kw_only:
340+
if get_origin(typ) is tuple:
341+
typ = get_args(typ)[0]
342+
elif field.multiple != "str" or typ is not str:
343+
raise TypeError(f"{name!r} is a varpos but does not have a tuple type annotation")
344+
elif get_origin(typ) is not dict:
345+
raise TypeError(f"{name!r} is a varkey but does not have a dict type annotation")
331346
cls_args.append(Arg(name, typ, field))
332347
for name, value in cls.__dict__.items():
333348
if isinstance(value, Field) and name not in cls_annotations:

src/arclet/alconna/ingedia/_handlers.py

+24-20
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

33
import re
4-
from typing import TYPE_CHECKING, Any, Iterable
4+
from typing import TYPE_CHECKING, Any, Iterable, Literal
55

66
from nepattern import ANY, STRING, AnyString, BasePattern
77
from tarina import Empty, lang, safe_eval, split_once
@@ -19,7 +19,7 @@
1919
PauseTriggered,
2020
ParamsUnmatched,
2121
)
22-
from ..typing import KWBool, MultiVar, _AllParamPattern, _StrMulti
22+
from ..typing import KWBool, _AllParamPattern
2323

2424
from ._util import levenshtein
2525

@@ -74,12 +74,14 @@ def _validate(argv: Argv, target: Arg[Any], value: BasePattern[Any, Any, Any], r
7474
result[target.name] = res._value # noqa
7575

7676

77-
def step_varpos(argv: Argv, args: _Args, slot: tuple[MultiVar, Arg], result: dict[str, Any]):
78-
value, arg = slot
77+
def step_varpos(argv: Argv, args: _Args, slot: tuple[int | Literal["+", "*", "str"], Arg], result: dict[str, Any]):
78+
flag, arg = slot
79+
value = arg.type_
7980
key = arg.name
81+
length = int(flag) if flag.__class__ is int else -1
8082
default_val = arg.field.default
8183
_result = []
82-
kwonly_seps = "".join([arg.type_.sep for arg in args.keyword_only.values()]) # type: ignore
84+
kwonly_seps = "".join([arg.field.kw_sep for arg in args.keyword_only.values()])
8385
count = 0
8486
while argv.current_index != argv.ndata:
8587
may_arg, _str = argv.next(arg.field.seps)
@@ -91,33 +93,36 @@ def step_varpos(argv: Argv, args: _Args, slot: tuple[MultiVar, Arg], result: dic
9193
if _str and kwonly_seps and split_once(pat.match(may_arg)["name"], kwonly_seps, argv.filter_crlf)[0] in args.keyword_only: # noqa: E501 # type: ignore
9294
argv.rollback(may_arg)
9395
break
94-
if _str and args.vars_keyword and "=" in may_arg: #args.vars_keyword[0][0].base.sep in may_arg:
96+
if _str and args.vars_keyword and args.vars_keyword[0][0] in may_arg:
9597
argv.rollback(may_arg)
9698
break
97-
if (res := value.base.validate(may_arg)).flag != "valid":
99+
if (res := value.validate(may_arg)).flag != "valid":
98100
argv.rollback(may_arg)
99101
break
100102
_result.append(res._value) # noqa
101103
count += 1
102-
if 0 < value.length <= count:
104+
if 0 < length <= count:
103105
break
104106
if not _result:
105107
if default_val is not Empty:
106108
_result = default_val if isinstance(default_val, Iterable) else ()
107-
elif value.flag == "*":
109+
elif flag == "*":
108110
_result = ()
109111
elif arg.field.optional:
110112
return
111113
else:
112114
raise ArgumentMissing(arg.field.get_missing_tips(lang.require("args", "missing").format(key=key)), arg)
113-
if isinstance(value, _StrMulti):
115+
if flag == "str":
114116
result[key] = arg.field.seps[0].join(_result)
115117
else:
116118
result[key] = tuple(_result)
117119

118120

119-
def step_varkey(argv: Argv, slot: tuple[MultiVar, Arg], result: dict[str, Any]):
120-
value, arg = slot
121+
def step_varkey(argv: Argv, slot: tuple[str, Arg], result: dict[str, Any]):
122+
kw_sep, arg = slot
123+
flag = arg.field.multiple
124+
length = int(flag) if flag.__class__ is int else -1
125+
value = arg.type_
121126
name = arg.name
122127
default_val = arg.field.default
123128
_result = {}
@@ -129,23 +134,23 @@ def step_varkey(argv: Argv, slot: tuple[MultiVar, Arg], result: dict[str, Any]):
129134
break
130135
if _str and may_arg in global_config.remainders:
131136
break
132-
if not (_kwarg := re.match(rf"^(-*[^{'='}]+){'='}(.*?)$", may_arg)):
137+
if not (_kwarg := re.match(rf"^(-*[^{kw_sep}]+){kw_sep}(.*?)$", may_arg)):
133138
argv.rollback(may_arg)
134139
break
135140
key = _kwarg[1]
136141
if not (_m_arg := _kwarg[2]):
137142
_m_arg, _ = argv.next(arg.field.seps)
138-
if (res := value.base.validate(_m_arg)).flag != "valid":
143+
if (res := value.validate(_m_arg)).flag != "valid":
139144
argv.rollback(may_arg)
140145
break
141146
_result[key] = res._value # noqa
142147
count += 1
143-
if 0 < value.length <= count:
148+
if 0 < length <= count:
144149
break
145150
if not _result:
146151
if default_val is not Empty:
147152
_result = default_val if isinstance(default_val, dict) else {}
148-
elif value.flag == "*":
153+
elif flag == "*":
149154
_result = {}
150155
elif arg.field.optional:
151156
return
@@ -158,8 +163,7 @@ def step_keyword(argv: Argv, args: _Args, result: dict[str, Any]):
158163
kwonly_seps = set()
159164
for arg in args.keyword_only.values():
160165
kwonly_seps.update(arg.field.seps)
161-
# kwonly_seps1 = "".join([arg.type_.sep for arg in args.keyword_only.values()]) # type: ignore
162-
kwonly_seps1 = "="
166+
kwonly_seps1 = "".join([arg.field.kw_sep for arg in args.keyword_only.values()])
163167
target = len(args.keyword_only)
164168
count = 0
165169
while count < target:
@@ -181,7 +185,7 @@ def step_keyword(argv: Argv, args: _Args, result: dict[str, Any]):
181185
):
182186
break
183187
for arg in args.keyword_only.values():
184-
if arg.type_.validate(may_arg).flag == "valid": # type: ignore
188+
if arg.type_.validate(may_arg).flag == "valid":
185189
raise InvalidParam(lang.require("args", "key_missing").format(target=may_arg, key=arg.name), arg)
186190
for name in args.keyword_only:
187191
if levenshtein(_key, name) >= argv.fuzzy_threshold:
@@ -193,7 +197,7 @@ def step_keyword(argv: Argv, args: _Args, result: dict[str, Any]):
193197
if isinstance(value, KWBool):
194198
_m_arg = key
195199
else:
196-
_m_arg, _ = argv.next("=") # (args.keyword_only[_key].separators)
200+
_m_arg, _ = argv.next(args.keyword_only[_key].separators)
197201
_validate(argv, arg, value, result, _m_arg, _str)
198202
count += 1
199203

src/arclet/alconna/typing.py

-48
Original file line numberDiff line numberDiff line change
@@ -103,58 +103,10 @@ def __calc_eq__(self, other): # pragma: no cover
103103
AllParam: _AllParamPattern[Any] = _AllParamPattern()
104104

105105

106-
class MultiVar(BasePattern[T, Any, Literal[MatchMode.KEEP]]):
107-
"""对可变参数的包装"""
108-
109-
base: BasePattern[T, Any, Any]
110-
flag: Literal["+", "*"]
111-
length: int
112-
113-
def __init__(self, value: TAValue[T], flag: int | Literal["+", "*"] = "+"):
114-
"""构建一个可变参数
115-
116-
Args:
117-
value (type | BasePattern): 参数的值
118-
flag (int | Literal["+", "*"]): 参数的标记
119-
"""
120-
self.base = value if isinstance(value, BasePattern) else parser(value) # type: ignore
121-
assert isinstance(self.base, BasePattern)
122-
if not isinstance(flag, int):
123-
alias = f"({self.base}{flag})"
124-
self.flag = flag
125-
self.length = -1
126-
elif flag > 1:
127-
alias = f"({self.base}+)[:{flag}]"
128-
self.flag = "+"
129-
self.length = flag
130-
else: # pragma: no cover
131-
alias = str(self.base)
132-
self.flag = "+"
133-
self.length = 1
134-
super().__init__(mode=MatchMode.KEEP, origin=self.base.origin, alias=alias)
135-
136-
def __repr__(self):
137-
return self.alias
138-
139-
140-
Nargs = MultiVar
141-
142-
143106
class KWBool(BasePattern):
144107
"""对布尔参数的包装"""
145108

146109

147-
class _StrMulti(MultiVar[str]):
148-
pass
149-
150-
151-
StrMulti = _StrMulti(str)
152-
"""特殊参数, 用于匹配多个字符串, 并将结果通过 `str.join` 合并"""
153-
154-
StrMulti.alias = "str+"
155-
StrMulti.refresh()
156-
157-
158110
def parent_frame_namespace(*, parent_depth: int = 2, force: bool = False) -> dict[str, Any] | None:
159111
"""We allow use of items in parent namespace to get around the issue with `get_type_hints` only looking in the
160112
global module namespace. See https://github.com/pydantic/pydantic/issues/2678#issuecomment-1008139014 -> Scope

src/arclet/alconna/v1/__init__.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,9 @@
3939
from arclet.alconna.typing import AllParam as AllParam # noqa: F401
4040
# from arclet.alconna.typing import KeyWordVar as KeyWordVar # noqa: F401
4141
# from arclet.alconna.typing import Kw as Kw # noqa: F401
42-
from arclet.alconna.typing import MultiVar as MultiVar # noqa: F401
43-
from arclet.alconna.typing import Nargs as Nargs # noqa: F401
44-
from arclet.alconna.typing import StrMulti as StrMulti # noqa: F401
42+
# from arclet.alconna.typing import MultiVar as MultiVar # noqa: F401
43+
# from arclet.alconna.typing import Nargs as Nargs # noqa: F401
44+
# from arclet.alconna.typing import StrMulti as StrMulti # noqa: F401
4545
# from arclet.alconna.typing import UnpackVar as UnpackVar # noqa: F401
4646
# from arclet.alconna.typing import Up as Up # noqa: F401
4747

tests/args_test.py

+15-18
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from nepattern import INTEGER, BasePattern, MatchMode, combine
44

5-
from arclet.alconna import Args, Nargs, StrMulti
5+
from arclet.alconna import Args
66
from devtool import analyse_args
77

88

@@ -55,19 +55,19 @@ def test_object():
5555

5656

5757
def test_multi():
58-
arg8 = Args.multi(Nargs(str, "+"))
58+
arg8 = Args.multi(str, multiple=True)
5959
assert analyse_args(arg8, ["a b c d"]).get("multi") == ("a", "b", "c", "d")
6060
assert analyse_args(arg8, [], raise_exception=False) != {"multi": ()}
61-
arg8_1 = Args.kwargs(Nargs(str, "+"), kw_only=True)
61+
arg8_1 = Args.kwargs(str, multiple=True, kw_only=True)
6262
assert analyse_args(arg8_1, ["a=b c=d"]).get("kwargs") == {"a": "b", "c": "d"}
63-
arg8_2 = Args.multi(Nargs(int, "*"))
63+
arg8_2 = Args.multi(int, multiple="*")
6464
assert analyse_args(arg8_2, ["1 2 3 4"]).get("multi") == (1, 2, 3, 4)
6565
assert analyse_args(arg8_2, []).get("multi") == ()
66-
arg8_3 = Args.multi(Nargs(int, 3))
66+
arg8_3 = Args.multi(int, multiple=3)
6767
assert analyse_args(arg8_3, ["1 2 3"]).get("multi") == (1, 2, 3)
6868
assert analyse_args(arg8_3, ["1 2"]).get("multi") == (1, 2)
6969
assert analyse_args(arg8_3, ["1 2 3 4"]).get("multi") == (1, 2, 3)
70-
arg8_4 = Args.multi(Nargs(str, "*")).kwargs(Nargs(str, "*"), kw_only=True)
70+
arg8_4 = Args.multi(str, multiple="*").kwargs(str, multiple="*", kw_only=True)
7171
assert analyse_args(arg8_4, ["1 2 3 4 a=b c=d"]).get("multi") == ("1", "2", "3", "4")
7272
assert analyse_args(arg8_4, ["1 2 3 4 a=b c=d"]).get("kwargs") == {
7373
"a": "b",
@@ -128,13 +128,12 @@ def test_kwonly():
128128
"width": 960,
129129
"height": 480,
130130
}
131-
# FIXME: kw_sep
132-
# arg14_2 = Args["foo", str]["bar", KeyWordVar(int, " ")]["baz", KeyWordVar(bool, ":")]
133-
# assert analyse_args(arg14_2, ["abc baz:false -bar 123"]) == {
134-
# "bar": 123,
135-
# "baz": False,
136-
# "foo": "abc",
137-
# }
131+
arg14_2 = Args.foo(str).bar(int, kw_only=True, kw_sep=" ").baz(bool, kw_only=True, kw_sep=":")
132+
assert analyse_args(arg14_2, ["abc baz:false -bar 123"]) == {
133+
"bar": 123,
134+
"baz": False,
135+
"foo": "abc",
136+
}
138137
arg14_3 = Args.foo(str).bar(int, kw_only=True).baz(bool, kw_only=True)
139138
assert analyse_args(arg14_3, ["abc baz=false bar=456"]) == {
140139
"bar": 456,
@@ -220,16 +219,14 @@ def test_annotated():
220219

221220

222221
def test_multi_multi():
223-
from arclet.alconna.typing import MultiVar
224-
225-
arg20 = Args["foo", MultiVar(str)]["bar", MultiVar(int)]
222+
arg20 = Args.foo(str, multiple=True).bar(int, multiple=True)
226223
assert analyse_args(arg20, ["a b -- 1 2"]) == {"foo": ("a", "b"), "bar": (1, 2)}
227224

228-
arg20_1 = Args["foo", MultiVar(int)]["bar", MultiVar(str)]
225+
arg20_1 = Args.foo(int, multiple=True).bar(str, multiple=True)
229226
assert analyse_args(arg20_1, ["1 2 -- a b"]) == {"foo": (1, 2), "bar": ("a", "b")}
230227
assert analyse_args(arg20_1, ["1 2 a b"]) == {"foo": (1, 2), "bar": ("a", "b")}
231228

232-
arg20_2 = Args["foo", str]["bar", StrMulti]
229+
arg20_2 = Args.foo(str).bar(str, multiple="str")
233230
assert analyse_args(arg20_2, ["a b"]) == {"foo": "a", "bar": "b"}
234231
assert analyse_args(arg20_2, ["a b c"]) == {"foo": "a", "bar": "b c"}
235232

0 commit comments

Comments
 (0)