Skip to content

Commit 0a1f02d

Browse files
fix: avoid dependeing on argparse_dataclass fork
1 parent 7f62bb9 commit 0a1f02d

File tree

3 files changed

+111
-3
lines changed

3 files changed

+111
-3
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ readme = "README.md"
88
version = "2.0.0"
99

1010
[tool.poetry.dependencies]
11-
argparse-dataclass = {git = "https://github.com/johanneskoester/argparse_dataclass.git"}
11+
argparse-dataclass = "^2.0.0"
1212
python = "^3.9"
1313
throttler = "^1.2.2"
1414
snakemake-interface-common = "^1.3.1"

snakemake_interface_executor_plugins/_common.py

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,110 @@
33
__email__ = "[email protected]"
44
__license__ = "MIT"
55

6+
from dataclasses import MISSING, Field
7+
from types import NoneType
8+
from typing import Any, Dict, List, Literal, Tuple, Union, get_args, get_origin
9+
from argparse_dataclass import _handle_bool_type
10+
11+
612
executor_plugin_prefix = "snakemake-executor-plugin-"
713
executor_plugin_module_prefix = executor_plugin_prefix.replace("-", "_")
14+
15+
16+
# Taken from https://github.com/mivade/argparse_dataclass/pull/59/files.
17+
# TODO remove once https://github.com/mivade/argparse_dataclass/pull/59/files is
18+
# merged and released
19+
def dataclass_field_to_argument_args(
20+
field: Field[Any],
21+
) -> Tuple[List[str], Dict[str, Any]]:
22+
"""Extract kwargs of ArgumentParser.add_argument from a dataclass field.
23+
24+
Returns pair of (args, kwargs) to be passed to ArgumentParser.add_argument.
25+
"""
26+
args = field.metadata.get("args", [f"--{field.name.replace('_', '-')}"])
27+
positional = not args[0].startswith("-")
28+
kwargs = {
29+
"type": field.metadata.get("type", field.type),
30+
"help": field.metadata.get("help", None),
31+
}
32+
33+
if field.metadata.get("args") and not positional:
34+
# We want to ensure that we store the argument based on the
35+
# name of the field and not whatever flag name was provided
36+
kwargs["dest"] = field.name
37+
38+
if field.metadata.get("choices") is not None:
39+
kwargs["choices"] = field.metadata["choices"]
40+
41+
# Support Literal types as an alternative means of specifying choices.
42+
if get_origin(field.type) is Literal:
43+
# Prohibit a potential collision with the choices field
44+
if field.metadata.get("choices") is not None:
45+
raise ValueError(
46+
f"Cannot infer type of items in field: {field.name}. "
47+
"Literal type arguments should not be combined with choices in the "
48+
"metadata. "
49+
"Remove the redundant choices field from the metadata."
50+
)
51+
52+
# Get the types of the arguments of the Literal
53+
types = [type(arg) for arg in get_args(field.type)]
54+
55+
# Make sure just a single type has been used
56+
if len(set(types)) > 1:
57+
raise ValueError(
58+
f"Cannot infer type of items in field: {field.name}. "
59+
"Literal type arguments should contain choices of a single type. "
60+
f"Instead, {len(set(types))} types where found: "
61+
+ ", ".join([type_.__name__ for type_ in set(types)])
62+
+ "."
63+
)
64+
65+
# Overwrite the type kwarg
66+
kwargs["type"] = types[0]
67+
# Use the literal arguments as choices
68+
kwargs["choices"] = get_args(field.type)
69+
70+
if field.metadata.get("metavar") is not None:
71+
kwargs["metavar"] = field.metadata["metavar"]
72+
73+
if field.metadata.get("nargs") is not None:
74+
kwargs["nargs"] = field.metadata["nargs"]
75+
if field.metadata.get("type") is None:
76+
# When nargs is specified, field.type should be a list,
77+
# or something equivalent, like typing.List.
78+
# Using it would most likely result in an error, so if the user
79+
# did not specify the type of the elements within the list, we
80+
# try to infer it:
81+
try:
82+
kwargs["type"] = get_args(field.type)[0] # get_args returns a tuple
83+
except IndexError:
84+
# get_args returned an empty tuple, type cannot be inferred
85+
raise ValueError(
86+
f"Cannot infer type of items in field: {field.name}. "
87+
"Try using a parameterized type hint, or "
88+
"specifying the type explicitly using metadata['type']"
89+
)
90+
91+
if field.default == field.default_factory == MISSING and not positional:
92+
kwargs["required"] = True
93+
else:
94+
kwargs["default"] = MISSING
95+
96+
if field.type is bool:
97+
_handle_bool_type(field, args, kwargs)
98+
elif get_origin(field.type) is Union:
99+
if field.metadata.get("type") is None:
100+
# Optional[X] is equivalent to Union[X, None].
101+
f_args = get_args(field.type)
102+
if len(f_args) == 2 and NoneType in f_args:
103+
arg = next(a for a in f_args if a is not NoneType)
104+
kwargs["type"] = arg
105+
else:
106+
raise TypeError(
107+
"For Union types other than 'Optional', a custom 'type' must be "
108+
"specified using "
109+
"'metadata'."
110+
)
111+
112+
return args, kwargs

snakemake_interface_executor_plugins/registry/plugin.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
__email__ = "[email protected]"
44
__license__ = "MIT"
55

6-
from argparse_dataclass import field_to_argument_args, fields
6+
from argparse_dataclass import fields
77
from dataclasses import MISSING, Field, dataclass
88
from typing import Any, Optional, Type
99
import copy
@@ -12,6 +12,9 @@
1212
from snakemake_interface_common.exceptions import WorkflowError
1313

1414
from snakemake_interface_executor_plugins.exceptions import InvalidPluginException
15+
from snakemake_interface_executor_plugins._common import (
16+
dataclass_field_to_argument_args,
17+
)
1518

1619
# Valid Argument types (to distinguish from empty dataclasses)
1720
ArgTypes = (str, int, float, bool, list)
@@ -67,7 +70,7 @@ def register_cli_args(self, argparser):
6770
settings = argparser.add_argument_group(f"{self.name} executor settings")
6871

6972
for field in fields(dc):
70-
args, kwargs = field_to_argument_args(field)
73+
args, kwargs = dataclass_field_to_argument_args(field)
7174

7275
if field.metadata.get("env_var"):
7376
kwargs["env_var"] = f"SNAKEMAKE_{prefixed_name.upper()}"

0 commit comments

Comments
 (0)