Skip to content

Commit b98ae5d

Browse files
authored
feat: improve create factory typing (#657)
1 parent da6ad4d commit b98ae5d

File tree

5 files changed

+40
-16
lines changed

5 files changed

+40
-16
lines changed

polyfactory/factories/base.py

Lines changed: 29 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -39,11 +39,12 @@
3939
TypedDict,
4040
TypeVar,
4141
cast,
42+
overload,
4243
)
4344
from uuid import UUID
4445

4546
from faker import Faker
46-
from typing_extensions import get_args, get_origin, get_original_bases
47+
from typing_extensions import Self, get_args, get_origin, get_original_bases
4748

4849
from polyfactory.constants import (
4950
DEFAULT_RANDOM,
@@ -89,6 +90,7 @@
8990

9091

9192
T = TypeVar("T")
93+
U = TypeVar("U")
9294
F = TypeVar("F", bound="BaseFactory[Any]")
9395

9496

@@ -372,7 +374,7 @@ def _get_config(cls) -> dict[str, Any]:
372374
}
373375

374376
@classmethod
375-
def _get_or_create_factory(cls, model: type) -> type[BaseFactory[Any]]:
377+
def _get_or_create_factory(cls, model: type[U]) -> type[BaseFactory[U]]:
376378
"""Get a factory from registered factories or generate a factory dynamically.
377379
378380
:param model: A model type.
@@ -553,13 +555,31 @@ def _create_generic_fn() -> Callable:
553555
**(cls._extra_providers or {}),
554556
}
555557

558+
@overload
556559
@classmethod
557560
def create_factory(
558-
cls: type[F],
559-
model: type[T] | None = None,
561+
cls,
562+
model: None = None,
563+
bases: tuple[type[BaseFactory[Any]], ...] | None = None,
564+
**kwargs: Any,
565+
) -> type[Self]: ...
566+
567+
@overload
568+
@classmethod
569+
def create_factory(
570+
cls,
571+
model: type[U],
560572
bases: tuple[type[BaseFactory[Any]], ...] | None = None,
561573
**kwargs: Any,
562-
) -> type[F]:
574+
) -> type[BaseFactory[U]]: ...
575+
576+
@classmethod
577+
def create_factory(
578+
cls,
579+
model: type[U] | None = None,
580+
bases: tuple[type[BaseFactory[Any]], ...] | None = None,
581+
**kwargs: Any,
582+
) -> type[Self | BaseFactory[U]]:
563583
"""Generate a factory for the given type dynamically.
564584
565585
:param model: A type to model. Defaults to current factory __model__ if any.
@@ -572,12 +592,13 @@ def create_factory(
572592
"""
573593
if model is None:
574594
try:
575-
model = cls.__model__
595+
model = cls.__model__ # pyright: ignore[reportAssignmentType]
576596
except AttributeError as ex:
577597
msg = "A 'model' argument is required when creating a new factory from a base one"
578598
raise TypeError(msg) from ex
599+
579600
return cast(
580-
"Type[F]",
601+
"Type[Self]",
581602
type(
582603
f"{model.__name__}Factory", # pyright: ignore[reportOptionalMemberAccess]
583604
(*(bases or ()), cls),
@@ -1081,7 +1102,7 @@ def process_kwargs_coverage(cls, **kwargs: Any) -> abc.Iterable[dict[str, Any]]:
10811102
yield resolved
10821103

10831104
@classmethod
1084-
def build(cls, **kwargs: Any) -> T:
1105+
def build(cls, *_: Any, **kwargs: Any) -> T:
10851106
"""Build an instance of the factory's __model__
10861107
10871108
:param kwargs: Any kwargs. If field names are set in kwargs, their values will be used.

polyfactory/factories/pydantic_factory.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,9 @@
9797

9898
from typing_extensions import NotRequired, TypeGuard
9999

100-
T = TypeVar("T", bound="BaseModelV1 | BaseModelV2") # pyright: ignore[reportInvalidTypeForm]
100+
from pydantic import BaseModel
101+
102+
T = TypeVar("T", bound="BaseModel")
101103

102104
_IS_PYDANTIC_V1 = VERSION.startswith("1")
103105

@@ -465,7 +467,8 @@ def build(
465467

466468
if "_build_context" not in kwargs:
467469
kwargs["_build_context"] = PydanticBuildContext(
468-
seen_models=set(), factory_use_construct=factory_use_construct
470+
seen_models=set(),
471+
factory_use_construct=factory_use_construct,
469472
)
470473

471474
processed_kwargs = cls.process_kwargs(**kwargs)
@@ -502,8 +505,8 @@ def _create_model(cls, _build_context: PydanticBuildContext, **kwargs: Any) -> T
502505
if cls._get_build_context(_build_context).get("factory_use_construct"):
503506
if _is_pydantic_v1_model(cls.__model__):
504507
return cls.__model__.construct(**kwargs) # type: ignore[return-value]
505-
return cls.__model__.model_construct(**kwargs) # type: ignore[return-value]
506-
return cls.__model__(**kwargs) # type: ignore[return-value]
508+
return cls.__model__.model_construct(**kwargs)
509+
return cls.__model__(**kwargs)
507510

508511
@classmethod
509512
def coverage(cls, factory_use_construct: bool = False, **kwargs: Any) -> abc.Iterator[T]:

tests/test_new_types.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,6 @@ class A:
160160
""",
161161
)
162162

163-
factory = DataclassFactory.create_factory(module.A) # type: ignore[var-annotated]
163+
factory = DataclassFactory.create_factory(module.A)
164164
result = factory.build()
165165
assert result.field in {"a", "b"}

tests/test_pydantic_v1_v2.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def test_build_v1_with_contrained_fields() -> None:
4949
ConstrainedInt = Annotated[int, Field(ge=100, le=200)]
5050
ConstrainedStr = Annotated[str, Field(min_length=1, max_length=3)]
5151

52-
class Foo(pydantic.v1.BaseModel): # pyright: ignore[reportGeneralTypeIssues]
52+
class Foo(BaseModelV1):
5353
a: ConstrainedInt
5454
b: ConstrainedStr
5555
c: Union[ConstrainedInt, ConstrainedStr]
@@ -58,7 +58,7 @@ class Foo(pydantic.v1.BaseModel): # pyright: ignore[reportGeneralTypeIssues]
5858
f: List[ConstrainedInt]
5959
g: Dict[ConstrainedInt, ConstrainedStr]
6060

61-
ModelFactory.create_factory(Foo).build() # type: ignore[type-var]
61+
ModelFactory.create_factory(Foo).build()
6262

6363

6464
def test_build_v2_with_contrained_fields() -> None:

tests/test_recursive_models.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ class PydanticNode(BaseModel):
5353

5454
@pytest.mark.parametrize("factory_use_construct", (True, False))
5555
def test_recursive_pydantic_models(factory_use_construct: bool) -> None:
56-
factory = ModelFactory[PydanticNode].create_factory(PydanticNode)
56+
factory = ModelFactory.create_factory(PydanticNode)
5757

5858
result = factory.build(factory_use_construct)
5959
assert result.child is _Sentinel, "Default is not used" # type: ignore[comparison-overlap]

0 commit comments

Comments
 (0)