diff --git a/dataclass_csv/dataclass_reader.py b/dataclass_csv/dataclass_reader.py index 6467c21..2579876 100644 --- a/dataclass_csv/dataclass_reader.py +++ b/dataclass_csv/dataclass_reader.py @@ -3,7 +3,7 @@ from datetime import date, datetime from distutils.util import strtobool -from typing import Union, Type, Optional, Sequence, Dict, Any, List +from typing import Union, Type, Optional, Sequence, Dict, Any, List, Generic, TypeVar import typing @@ -12,6 +12,8 @@ from collections import Counter +T = TypeVar("T") + def _verify_duplicate_header_items(header): if header is not None and len(header) == 0: @@ -45,11 +47,11 @@ def get_args(t): return tuple() -class DataclassReader: +class DataclassReader(Generic[T]): def __init__( self, f: Any, - cls: Type[object], + cls: Type[T], fieldnames: Optional[Sequence[str]] = None, restkey: Optional[str] = None, restval: Optional[Any] = None, @@ -183,7 +185,7 @@ def _parse_date_value(self, field, date_value, field_type): else: return datetime_obj - def _process_row(self, row): + def _process_row(self, row) -> T: values = dict() for field in dataclasses.fields(self._cls): @@ -242,7 +244,7 @@ def _process_row(self, row): values[field.name] = transformed_value return self._cls(**values) - def __next__(self): + def __next__(self) -> T: row = next(self._reader) return self._process_row(row) diff --git a/dataclass_csv/dataclass_reader.pyi b/dataclass_csv/dataclass_reader.pyi index 155bd11..fb721c9 100644 --- a/dataclass_csv/dataclass_reader.pyi +++ b/dataclass_csv/dataclass_reader.pyi @@ -1,11 +1,14 @@ from .field_mapper import FieldMapper as FieldMapper -from typing import Any, Optional, Sequence, Type +from typing import Any, Optional, Sequence, Type, Generic, TypeVar -class DataclassReader: +T = TypeVar("T") + + +class DataclassReader(Generic[T]): def __init__( self, f: Any, - cls: Type[object], + cls: Type[T], fieldnames: Optional[Sequence[str]] = ..., restkey: Optional[str] = ..., restval: Optional[Any] = ..., @@ -13,6 +16,6 @@ class DataclassReader: *args: Any, **kwds: Any ) -> None: ... - def __next__(self) -> None: ... + def __next__(self) -> T: ... def __iter__(self) -> Any: ... def map(self, csv_fieldname: str) -> FieldMapper: ...