|
3 | 3 | |
4 | 4 | __license__ = "MIT" |
5 | 5 |
|
| 6 | +from dataclasses import MISSING, Field |
| 7 | +from types import NoneType |
| 8 | +from typing import Any, Dict, List, Literal, Tuple, Union, get_args, get_origin |
| 9 | +from argparse_dataclass import _handle_bool_type |
| 10 | + |
| 11 | + |
6 | 12 | executor_plugin_prefix = "snakemake-executor-plugin-" |
7 | 13 | executor_plugin_module_prefix = executor_plugin_prefix.replace("-", "_") |
| 14 | + |
| 15 | + |
| 16 | +# Taken from https://github.com/mivade/argparse_dataclass/pull/59/files. |
| 17 | +# TODO remove once https://github.com/mivade/argparse_dataclass/pull/59/files is |
| 18 | +# merged and released |
| 19 | +def dataclass_field_to_argument_args( |
| 20 | + field: Field[Any], |
| 21 | +) -> Tuple[List[str], Dict[str, Any]]: |
| 22 | + """Extract kwargs of ArgumentParser.add_argument from a dataclass field. |
| 23 | +
|
| 24 | + Returns pair of (args, kwargs) to be passed to ArgumentParser.add_argument. |
| 25 | + """ |
| 26 | + args = field.metadata.get("args", [f"--{field.name.replace('_', '-')}"]) |
| 27 | + positional = not args[0].startswith("-") |
| 28 | + kwargs = { |
| 29 | + "type": field.metadata.get("type", field.type), |
| 30 | + "help": field.metadata.get("help", None), |
| 31 | + } |
| 32 | + |
| 33 | + if field.metadata.get("args") and not positional: |
| 34 | + # We want to ensure that we store the argument based on the |
| 35 | + # name of the field and not whatever flag name was provided |
| 36 | + kwargs["dest"] = field.name |
| 37 | + |
| 38 | + if field.metadata.get("choices") is not None: |
| 39 | + kwargs["choices"] = field.metadata["choices"] |
| 40 | + |
| 41 | + # Support Literal types as an alternative means of specifying choices. |
| 42 | + if get_origin(field.type) is Literal: |
| 43 | + # Prohibit a potential collision with the choices field |
| 44 | + if field.metadata.get("choices") is not None: |
| 45 | + raise ValueError( |
| 46 | + f"Cannot infer type of items in field: {field.name}. " |
| 47 | + "Literal type arguments should not be combined with choices in the " |
| 48 | + "metadata. " |
| 49 | + "Remove the redundant choices field from the metadata." |
| 50 | + ) |
| 51 | + |
| 52 | + # Get the types of the arguments of the Literal |
| 53 | + types = [type(arg) for arg in get_args(field.type)] |
| 54 | + |
| 55 | + # Make sure just a single type has been used |
| 56 | + if len(set(types)) > 1: |
| 57 | + raise ValueError( |
| 58 | + f"Cannot infer type of items in field: {field.name}. " |
| 59 | + "Literal type arguments should contain choices of a single type. " |
| 60 | + f"Instead, {len(set(types))} types where found: " |
| 61 | + + ", ".join([type_.__name__ for type_ in set(types)]) |
| 62 | + + "." |
| 63 | + ) |
| 64 | + |
| 65 | + # Overwrite the type kwarg |
| 66 | + kwargs["type"] = types[0] |
| 67 | + # Use the literal arguments as choices |
| 68 | + kwargs["choices"] = get_args(field.type) |
| 69 | + |
| 70 | + if field.metadata.get("metavar") is not None: |
| 71 | + kwargs["metavar"] = field.metadata["metavar"] |
| 72 | + |
| 73 | + if field.metadata.get("nargs") is not None: |
| 74 | + kwargs["nargs"] = field.metadata["nargs"] |
| 75 | + if field.metadata.get("type") is None: |
| 76 | + # When nargs is specified, field.type should be a list, |
| 77 | + # or something equivalent, like typing.List. |
| 78 | + # Using it would most likely result in an error, so if the user |
| 79 | + # did not specify the type of the elements within the list, we |
| 80 | + # try to infer it: |
| 81 | + try: |
| 82 | + kwargs["type"] = get_args(field.type)[0] # get_args returns a tuple |
| 83 | + except IndexError: |
| 84 | + # get_args returned an empty tuple, type cannot be inferred |
| 85 | + raise ValueError( |
| 86 | + f"Cannot infer type of items in field: {field.name}. " |
| 87 | + "Try using a parameterized type hint, or " |
| 88 | + "specifying the type explicitly using metadata['type']" |
| 89 | + ) |
| 90 | + |
| 91 | + if field.default == field.default_factory == MISSING and not positional: |
| 92 | + kwargs["required"] = True |
| 93 | + else: |
| 94 | + kwargs["default"] = MISSING |
| 95 | + |
| 96 | + if field.type is bool: |
| 97 | + _handle_bool_type(field, args, kwargs) |
| 98 | + elif get_origin(field.type) is Union: |
| 99 | + if field.metadata.get("type") is None: |
| 100 | + # Optional[X] is equivalent to Union[X, None]. |
| 101 | + f_args = get_args(field.type) |
| 102 | + if len(f_args) == 2 and NoneType in f_args: |
| 103 | + arg = next(a for a in f_args if a is not NoneType) |
| 104 | + kwargs["type"] = arg |
| 105 | + else: |
| 106 | + raise TypeError( |
| 107 | + "For Union types other than 'Optional', a custom 'type' must be " |
| 108 | + "specified using " |
| 109 | + "'metadata'." |
| 110 | + ) |
| 111 | + |
| 112 | + return args, kwargs |
0 commit comments