Skip to content

Commit

Permalink
Solve the issue pallets-eco#1312
Browse files Browse the repository at this point in the history
1. Provide a descriptor `ModelGetter` for solving the `db.Model` type from the `db` type dynamically.
2. Add `t.Type[sa_orm.MappedAsDataclass]` to `_FSA_MCT`.
3. Let `SQLAlchemy(...)` annotated by the provided `model_class` type.
  • Loading branch information
cainmagi committed Mar 27, 2024
1 parent fec440f commit dcc2a23
Showing 1 changed file with 167 additions and 41 deletions.
208 changes: 167 additions & 41 deletions src/flask_sqlalchemy/extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,12 @@
import sqlalchemy.event as sa_event
import sqlalchemy.exc as sa_exc
import sqlalchemy.orm as sa_orm
import typing_extensions as te
from flask import abort
from flask import current_app
from flask import Flask
from flask import has_app_context
from sqlalchemy.util import typing as compat_typing

from .model import _QueryProperty
from .model import BindMixin
Expand All @@ -32,22 +34,132 @@


# Type accepted for model_class argument
_FSA_MCT = t.TypeVar(
"_FSA_MCT",
bound=t.Union[
t.Type[Model],
sa_orm.DeclarativeMeta,
t.Type[sa_orm.DeclarativeBase],
t.Type[sa_orm.DeclarativeBaseNoMeta],
],
)
_FSA_MCT = t.Union[
t.Type[Model],
sa_orm.DeclarativeMeta,
t.Type[sa_orm.DeclarativeBase],
t.Type[sa_orm.DeclarativeBaseNoMeta],
t.Type[sa_orm.MappedAsDataclass],
]
_FSA_MCT_T = t.TypeVar("_FSA_MCT_T", bound=_FSA_MCT, covariant=True)


# Type returned by make_declarative_base
class _FSAModel(Model):
metadata: sa.MetaData


if t.TYPE_CHECKING:

class _FSAModel_KW(_FSAModel):
def __init__(self, **kw: t.Any) -> None:
...

else:
# To minimize side effects, the type hint only works for static type checker.
# At run time, `_FSAModel_KW` falls back to `_FSAModel`
_FSAModel_KW = _FSAModel


if t.TYPE_CHECKING:

@compat_typing.dataclass_transform(
field_specifiers=(
sa_orm.MappedColumn,
sa_orm.RelationshipProperty,
sa_orm.Composite,
sa_orm.Synonym,
sa_orm.mapped_column,
sa_orm.relationship,
sa_orm.composite,
sa_orm.synonym,
sa_orm.deferred,
),
)
class _FSAModel_DataClass(_FSAModel):
...

else:
# To minimize side effects, the type hint only works for static type checker.
# At run time, `_FSAModel_DataClass` falls back to `_FSAModel`
_FSAModel_DataClass = _FSAModel


class ModelGetter:
"""Model getter for the ``SQLAlchemy().Model`` property.
This getter is used for determining the correct type of ``SQLAlchemy().Model``.
When ``SQLAlchemy`` is initialized by
.. code-block:: python
db = SQLAlchemy(model_class=MappedAsDataclass)
the ``db.Model`` property needs to be a class decorated by ``dataclass_transform``.
Otherwise, the ``db.Model`` property needs to provide a synthesized initialization
method accepting unknown keyword arguments. These keyword arguments are not
annotated but limited in the range of data items. This rule is guaranteed by the
featuers of all other candidates of ``model_class``.
Calling the class property ``SQLAlchemy.Model`` will return this descriptor
directly.
"""

# This variant is at first. Its priority is highest for making SQLAlchemy[Any]
# exports a Model with type[_FSAModel_KW].
# Note that in actual using cases, users do not need to inherit Model classes.
@te.overload
def __get__(
self, obj: SQLAlchemy[t.Type[Model]], obj_cls: t.Any = None
) -> t.Type[_FSAModel_KW]:
...

# This variant needs to be prior than DeclarativeBase, because a class may inherit
# multiple classes. When both MappedAsDataclass and DeclarativeBase are in the MRO
# list, this configuration make type[_FSAModel_DataClass] preferred.
@te.overload
def __get__(
self, obj: SQLAlchemy[t.Type[sa_orm.MappedAsDataclass]], obj_cls: t.Any = None
) -> t.Type[_FSAModel_DataClass]:
...

@te.overload
def __get__(
self, obj: SQLAlchemy[t.Type[sa_orm.DeclarativeBase]], obj_cls: t.Any = None
) -> t.Type[_FSAModel_KW]:
...

@te.overload
def __get__(
self,
obj: SQLAlchemy[t.Type[sa_orm.DeclarativeBaseNoMeta]],
obj_cls: t.Any = None,
) -> t.Type[_FSAModel_KW]:
...

@te.overload
def __get__(
self, obj: SQLAlchemy[sa_orm.DeclarativeMeta], obj_cls: t.Any = None
) -> t.Type[_FSAModel_KW]:
...

@te.overload
def __get__(
self: te.Self, obj: None, obj_cls: t.Optional[t.Type[SQLAlchemy[t.Any]]] = None
) -> t.Type[_FSAModel]:
...

def __get__(
self: te.Self, obj: t.Optional[SQLAlchemy[t.Any]], obj_cls: t.Any = None
) -> t.Union[te.Self, t.Type[Model], t.Type[t.Any]]:
if isinstance(obj, SQLAlchemy):
return obj._Model
else:
return self


def _get_2x_declarative_bases(
model_class: _FSA_MCT,
) -> list[t.Type[t.Union[sa_orm.DeclarativeBase, sa_orm.DeclarativeBaseNoMeta]]]:
Expand All @@ -58,7 +170,7 @@ def _get_2x_declarative_bases(
]


class SQLAlchemy:
class SQLAlchemy(t.Generic[_FSA_MCT_T]):
"""Integrates SQLAlchemy with Flask. This handles setting up one or more engines,
associating tables and models with specific engines, and cleaning up connections and
sessions after each request.
Expand Down Expand Up @@ -168,7 +280,7 @@ def __init__(
metadata: sa.MetaData | None = None,
session_options: dict[str, t.Any] | None = None,
query_class: type[Query] = Query,
model_class: _FSA_MCT = Model, # type: ignore[assignment]
model_class: _FSA_MCT_T = Model, # type: ignore[assignment]
engine_options: dict[str, t.Any] | None = None,
add_models_to_shell: bool = True,
disable_autonaming: bool = False,
Expand Down Expand Up @@ -241,29 +353,17 @@ def __init__(
This is a subclass of SQLAlchemy's ``Table`` rather than a function.
"""

self.Model = self._make_declarative_base(
self._Model = self._make_declarative_base(
model_class, disable_autonaming=disable_autonaming
)
"""A SQLAlchemy declarative model class. Subclass this to define database
models.
If a model does not set ``__tablename__``, it will be generated by converting
the class name from ``CamelCase`` to ``snake_case``. It will not be generated
if the model looks like it uses single-table inheritance.
If a model or parent class sets ``__bind_key__``, it will use that metadata and
database engine. Otherwise, it will use the default :attr:`metadata` and
:attr:`engine`. This is ignored if the model sets ``metadata`` or ``__table__``.
For code using the SQLAlchemy 1.x API, customize this model by subclassing
:class:`.Model` and passing the ``model_class`` parameter to the extension.
A fully created declarative model class can be
passed as well, to use a custom metaclass.
For code using the SQLAlchemy 2.x API, customize this model by subclassing
:class:`sqlalchemy.orm.DeclarativeBase` or
:class:`sqlalchemy.orm.DeclarativeBaseNoMeta`
and passing the ``model_class`` parameter to the extension.
"""A SQLAlchemy declarative model class. This private model class is returned
by ``_make_declarative_base``.
At run time, this class is the same as ``SQLAlchemy.Model``. Accessing
``SQLAlchemy.Model`` rather than this class is more recommended because
``SQLAlchemy.Model`` can provide better type hints.
:meta private:
"""

if engine_options is None:
Expand All @@ -277,6 +377,31 @@ def __init__(
if app is not None:
self.init_app(app)

# Need to be placed after __init__ because __init__ takes a default value
# named `Model`.
Model = ModelGetter()
"""A SQLAlchemy declarative model class. Subclass this to define database
models.
If a model does not set ``__tablename__``, it will be generated by converting
the class name from ``CamelCase`` to ``snake_case``. It will not be generated
if the model looks like it uses single-table inheritance.
If a model or parent class sets ``__bind_key__``, it will use that metadata and
database engine. Otherwise, it will use the default :attr:`metadata` and
:attr:`engine`. This is ignored if the model sets ``metadata`` or ``__table__``.
For code using the SQLAlchemy 1.x API, customize this model by subclassing
:class:`.Model` and passing the ``model_class`` parameter to the extension.
A fully created declarative model class can be
passed as well, to use a custom metaclass.
For code using the SQLAlchemy 2.x API, customize this model by subclassing
:class:`sqlalchemy.orm.DeclarativeBase` or
:class:`sqlalchemy.orm.DeclarativeBaseNoMeta`
and passing the ``model_class`` parameter to the extension.
"""

def __repr__(self) -> str:
if not has_app_context():
return f"<{type(self).__name__}>"
Expand Down Expand Up @@ -503,9 +628,7 @@ def __new__(
return Table

def _make_declarative_base(
self,
model_class: _FSA_MCT,
disable_autonaming: bool = False,
self, model_class: _FSA_MCT, disable_autonaming: bool = False
) -> t.Type[_FSAModel]:
"""Create a SQLAlchemy declarative model class. The result is available as
:attr:`Model`.
Expand Down Expand Up @@ -534,7 +657,7 @@ def _make_declarative_base(
``model`` can be an already created declarative model class.
"""
model: t.Type[_FSAModel]
declarative_bases = _get_2x_declarative_bases(model_class)
declarative_bases = _get_2x_declarative_bases(t.cast(t.Any, model_class))
if len(declarative_bases) > 1:
# raise error if more than one declarative base is found
raise ValueError(
Expand All @@ -547,11 +670,14 @@ def _make_declarative_base(
mixin_classes = [BindMixin, NameMixin, Model]
if disable_autonaming:
mixin_classes.remove(NameMixin)
model = types.new_class(
"FlaskSQLAlchemyBase",
(*mixin_classes, *model_class.__bases__),
{"metaclass": type(declarative_bases[0])},
lambda ns: ns.update(body),
model = t.cast(
t.Type[_FSAModel],
types.new_class(
"FlaskSQLAlchemyBase",
(*mixin_classes, *model_class.__bases__),
{"metaclass": type(declarative_bases[0])},
lambda ns: ns.update(body),
),
)
elif not isinstance(model_class, sa_orm.DeclarativeMeta):
metadata = self._make_metadata(None)
Expand Down

0 comments on commit dcc2a23

Please sign in to comment.