From 7e33dc928580f31e74ed5ba91b14780442fb0fe0 Mon Sep 17 00:00:00 2001 From: Daniel Jackson <52429462+danjjackson@users.noreply.github.com> Date: Tue, 30 Jan 2024 16:02:30 +0100 Subject: [PATCH 1/3] Update dataclass_reader.py Make DataclassReader a Generic class whose type is defined by the type of dataclass supplied to the init function --- dataclass_csv/dataclass_reader.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/dataclass_csv/dataclass_reader.py b/dataclass_csv/dataclass_reader.py index 6467c21..2efcf92 100644 --- a/dataclass_csv/dataclass_reader.py +++ b/dataclass_csv/dataclass_reader.py @@ -3,7 +3,7 @@ from datetime import date, datetime from distutils.util import strtobool -from typing import Union, Type, Optional, Sequence, Dict, Any, List +from typing import Union, Type, Optional, Sequence, Dict, Any, List, Generic, TypeVar import typing @@ -12,6 +12,7 @@ from collections import Counter +T = TypeVar("T") def _verify_duplicate_header_items(header): if header is not None and len(header) == 0: @@ -45,11 +46,11 @@ def get_args(t): return tuple() -class DataclassReader: +class DataclassReader(Generic[T]): def __init__( self, f: Any, - cls: Type[object], + cls: Type[T], fieldnames: Optional[Sequence[str]] = None, restkey: Optional[str] = None, restval: Optional[Any] = None, @@ -183,7 +184,7 @@ def _parse_date_value(self, field, date_value, field_type): else: return datetime_obj - def _process_row(self, row): + def _process_row(self, row) -> T: values = dict() for field in dataclasses.fields(self._cls): @@ -242,7 +243,7 @@ def _process_row(self, row): values[field.name] = transformed_value return self._cls(**values) - def __next__(self): + def __next__(self) -> T: row = next(self._reader) return self._process_row(row) From de262a9a2fe4e36a09f0a3632ca72f9bc65fdedf Mon Sep 17 00:00:00 2001 From: Daniel Jackson <52429462+danjjackson@users.noreply.github.com> Date: Tue, 30 Jan 2024 16:03:08 +0100 Subject: [PATCH 2/3] Fix formatting --- dataclass_csv/dataclass_reader.py | 1 + 1 file changed, 1 insertion(+) diff --git a/dataclass_csv/dataclass_reader.py b/dataclass_csv/dataclass_reader.py index 2efcf92..2579876 100644 --- a/dataclass_csv/dataclass_reader.py +++ b/dataclass_csv/dataclass_reader.py @@ -14,6 +14,7 @@ T = TypeVar("T") + def _verify_duplicate_header_items(header): if header is not None and len(header) == 0: return From 93fec02ff123de6702ed7d6b1549a716c3c394f0 Mon Sep 17 00:00:00 2001 From: Daniel Jackson <52429462+danjjackson@users.noreply.github.com> Date: Tue, 30 Jan 2024 16:06:51 +0100 Subject: [PATCH 3/3] Make a generic class --- dataclass_csv/dataclass_reader.pyi | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/dataclass_csv/dataclass_reader.pyi b/dataclass_csv/dataclass_reader.pyi index 155bd11..fb721c9 100644 --- a/dataclass_csv/dataclass_reader.pyi +++ b/dataclass_csv/dataclass_reader.pyi @@ -1,11 +1,14 @@ from .field_mapper import FieldMapper as FieldMapper -from typing import Any, Optional, Sequence, Type +from typing import Any, Optional, Sequence, Type, Generic, TypeVar -class DataclassReader: +T = TypeVar("T") + + +class DataclassReader(Generic[T]): def __init__( self, f: Any, - cls: Type[object], + cls: Type[T], fieldnames: Optional[Sequence[str]] = ..., restkey: Optional[str] = ..., restval: Optional[Any] = ..., @@ -13,6 +16,6 @@ class DataclassReader: *args: Any, **kwds: Any ) -> None: ... - def __next__(self) -> None: ... + def __next__(self) -> T: ... def __iter__(self) -> Any: ... def map(self, csv_fieldname: str) -> FieldMapper: ...