diff --git a/docs/advanced.rst b/docs/advanced.rst index f04535ec..f7240d3e 100644 --- a/docs/advanced.rst +++ b/docs/advanced.rst @@ -193,9 +193,34 @@ About Queryset Methods * :meth:`~django.db.models.query.QuerySet.distinct` works as expected. It only regards the fields of the base class, but this should never make a difference. -* :meth:`~django.db.models.query.QuerySet.select_related` works just as usual, but it can not - (yet) be used to select relations in inherited models (like - ``ModelA.objects.select_related('ModelC___fieldxy')`` ) +* :meth:`~django.db.models.query.QuerySet.select_related` works just as usual with the + exception that the query set must be derived from a PolymorphicRelatedQuerySetMixin + or PolymorphicRelatedQuerySet. + + This can be achieved by using a custom manager + + class NonPolyModel(models.Model): + relation = models.ForeignKey(BasePolyModel, on_delete=models.CASCADE) + objects = models.Manager.from_queryset(PolymorphicRelatedQuerySet)() + + or by converting a models queryset using + + class NonPolyModel(models.Model): + relation = models.ForeignKey(BasePolyModel, on_delete=models.CASCADE) + objects = models.Manager.from_queryset(QuerySet)() + + ``convert_to_polymorphic_queryset(NonPolyModel.objects).filter(...)`` + + To select related fields the model name comes after the field name and set the + field. + ``ModelA.objects.filter(....).select_related('field___TargetModel__subfield')``. + or using the polymorphic added related fieldname which is normally the lowercase + version of the model name. + ``ModelA.objects.filter(....).select_related('field__targetmodel__subfield')`` + + This automatically manages the via models between the model specified in the related + field and the target model. + ``ModelA.objects.filter(....).select_related('field__targetparentmodel__targetmodel__subfield')`` * :meth:`~django.db.models.query.QuerySet.extra` works as expected (it returns polymorphic results) but currently has one restriction: The resulting objects are required to have a unique diff --git a/src/polymorphic/query.py b/src/polymorphic/query.py index 8d582297..890633ea 100644 --- a/src/polymorphic/query.py +++ b/src/polymorphic/query.py @@ -3,15 +3,21 @@ """ import copy +import functools +import operator from collections import defaultdict from django.contrib.contenttypes.models import ContentType -from django.core.exceptions import FieldDoesNotExist +from django.core.exceptions import FieldDoesNotExist, FieldError from django.db import connections, models -from django.db.models import FilteredRelation -from django.db.models.query import ModelIterable, Q, QuerySet +from django.db.models import FilteredRelation, Manager +from django.db.models.constants import LOOKUP_SEP +from django.db.models.query import ModelIterable, Q, QuerySet, RelatedPopulator from .query_translate import ( + _create_base_path, + _get_all_sub_models, + _get_query_related_name, translate_polymorphic_field_path, translate_polymorphic_filter_definitions_in_args, translate_polymorphic_filter_definitions_in_kwargs, @@ -25,31 +31,273 @@ """ -class PolymorphicModelIterable(ModelIterable): - """ - ModelIterable for PolymorphicModel +def merge_dicts(primary, secondary): + """Deep merge two dicts - Yields real instances if qs.polymorphic_disabled is False, - otherwise acts like a regular ModelIterable. + Items from the primary dict are preserved in preference to those on the + secondary dict""" + + for k, v in secondary.items(): + if k in primary: + primary[k] = merge_dicts(primary[k], v) + else: + primary[k] = copy.deepcopy(v) + return primary + + +def search_object_cache(obj, source_model, target_model): + for search_part in _create_base_path(source_model, target_model).split("__"): + try: + obj = obj._state.fields_cache[search_part] + except KeyError: + return + return obj + + +class VanillaRelatedPopulator(RelatedPopulator): + def __init__(self, klass_info, select, db): + super().__init__(klass_info, select, db) + self.field = klass_info["field"] + self.reverse = klass_info["reverse"] + # replace replated populator with possibly a polymorphic version + # this is needed for relation across a non poly model + self.related_populators = get_related_populators(klass_info, select, self.db) + + def build_related(self, row, from_obj, *_): + self.populate(row, from_obj) + + +class RelatedPolymorphicPopulator: + """ + RelatedPopulator is used for select_related() object instantiation. + The idea is that each select_related() model will be populated by a + different RelatedPopulator instance. The RelatedPopulator instances get + klass_info and select (computed in SQLCompiler) plus the used db as + input for initialization. That data is used to compute which columns + to use, how to instantiate the model, and how to populate the links + between the objects. + The actual creation of the objects is done in populate() method. This + method gets row and from_obj as input and populates the select_related() + model instance. """ - def __iter__(self): - base_iter = super().__iter__() - if self.queryset.polymorphic_disabled: - return base_iter - return self._polymorphic_iterator(base_iter) + def __init__(self, klass_info, select, db): + self.db = db + # Pre-compute needed attributes. The attributes are: + # - model_cls: the possibly deferred model class to instantiate + # - either: + # - cols_start, cols_end: usually the columns in the row are + # in the same order model_cls.__init__ expects them, so we + # can instantiate by model_cls(*row[cols_start:cols_end]) + # - reorder_for_init: When select_related descends to a child + # class, then we want to reuse the already selected parent + # data. However, in this case the parent data isn't necessarily + # in the same order that Model.__init__ expects it to be, so + # we have to reorder the parent data. The reorder_for_init + # attribute contains a function used to reorder the field data + # in the order __init__ expects it. + # - pk_idx: the index of the primary key field in the reordered + # model data. Used to check if a related object exists at all. + # - init_list: the field attnames fetched from the database. For + # deferred models this isn't the same as all attnames of the + # model's fields. + # - related_populators: a list of RelatedPopulator instances if + # select_related() descends to related models from this model. + # - local_setter, remote_setter: Methods to set cached values on + # the object being populated and on the remote object. Usually + # these are Field.set_cached_value() methods. + select_fields = klass_info["select_fields"] + from_parent = klass_info["from_parent"] + if not from_parent: + self.cols_start = select_fields[0] + self.cols_end = select_fields[-1] + 1 + self.init_list = [f[0].target.attname for f in select[self.cols_start : self.cols_end]] + self.reorder_for_init = None + else: + attname_indexes = {select[idx][0].target.attname: idx for idx in select_fields} + model_init_attnames = (f.attname for f in klass_info["model"]._meta.concrete_fields) + self.init_list = [ + attname for attname in model_init_attnames if attname in attname_indexes + ] + self.reorder_for_init = operator.itemgetter( + *[attname_indexes[attname] for attname in self.init_list] + ) - def _polymorphic_iterator(self, base_iter): - """ - Here we do the same as:: + self.model_cls = klass_info["model"] + self.pk_idx = self.init_list.index(self.model_cls._meta.pk.attname) + self.related_populators = get_related_populators(klass_info, select, self.db) + self.local_setter = klass_info["local_setter"] + self.remote_setter = klass_info["remote_setter"] + self.field = klass_info["field"] + self.reverse = klass_info["reverse"] + self.content_type_manager = ContentType.objects.db_manager(self.db) + self.model_class_id = self.content_type_manager.get_for_model( + self.model_cls, for_concrete_model=False + ).pk + self.concrete_model_class_id = self.content_type_manager.get_for_model( + self.model_cls, for_concrete_model=True + ).pk + + def build_related(self, row, from_obj, post_actions): + if self.reorder_for_init: + obj_data = self.reorder_for_init(row) + else: + obj_data = row[self.cols_start : self.cols_end] + + if obj_data[self.pk_idx] is None: + obj = None + else: + obj = self.model_cls.from_db(self.db, self.init_list, obj_data) + self.post_build_modify( + obj, + from_obj, + post_actions, + functools.partial(self._populate, row, from_obj, post_actions), + ) - real_results = queryset._get_real_instances(list(base_iter)) - for o in real_results: yield o + def _populate(self, row, from_obj, post_actions, obj): + for rel_iter in self.related_populators: + rel_iter.build_related(row, obj, post_actions) - but it requests the objects in chunks from the database, - with QuerySet.iterator(chunk_size) per chunk + self.local_setter(from_obj, obj) + if obj is not None: + self.remote_setter(obj, from_obj) + + def post_build_modify(self, base_object, from_obj, post_actions, populate_fn): + if not hasattr(base_object, "polymorphic_ctype_id"): + populate_fn(base_object) + elif base_object.polymorphic_ctype_id == self.model_class_id: + # Real class is exactly the same as base class, go straight to results + populate_fn(base_object) + else: + real_concrete_class = base_object.get_real_instance_class() + real_concrete_class_id = base_object.get_real_concrete_instance_class_id() + + if real_concrete_class_id is None: + # Dealing with a stale content type + populate_fn(None) + return False + elif real_concrete_class_id == self.concrete_model_class_id: + # Real and base classes share the same concrete ancestor, + # upcast it and put it in the results + populate_fn(transmogrify(real_concrete_class, base_object)) + return False + else: + # This model has a concrete derived class: either track it for bulk + # retrieval or if it is already fetched as part of a select_related + # enable pivoting to that object + real_concrete_class = self.content_type_manager.get_for_id( + real_concrete_class_id + ).model_class() + populate_fn(base_object) + post_actions.append( + ( + functools.partial( + self.pivot_onto_cached_subclass, + from_obj, + base_object, + real_concrete_class, + ), + populate_fn, + ) + ) + + def pivot_onto_cached_subclass(self, from_obj, obj, model_target_cls): + """Pivot to final polymorphic class. + + Pivot the object created from the base query onto the true polymorphic + instance, we need to ensure that this is only done on objects that are + from non parent-child type relationships. + + If we cannot pivot we return info to be used in the PolymorphicModelIterable + to ensure the correct model loaded from the additional bulk queries """ + original = obj + parents = model_target_cls()._get_inheritance_relation_fields_and_models() + for cls in reversed(model_target_cls.mro()[: -len(self.model_cls.mro())]): + for rel_iter in self.related_populators: + if not isinstance( + rel_iter, (VanillaRelatedPopulator, RelatedPolymorphicPopulator) + ): + # NOTE: We don't know how to handle this type of populator! + continue + if rel_iter.reverse and rel_iter.model_cls is cls: + if rel_iter.field.name in parents.keys(): + obj = getattr(obj, rel_iter.field.remote_field.name) + + if not isinstance(obj, model_target_cls): + # This allow pivoting of object that are descendants of the original field + if not original._meta.get_path_to_parent(from_obj._meta.model): + obj = search_object_cache(original, original._meta.model, model_target_cls) + + if isinstance(obj, model_target_cls): + # We only want to pivot onto a field from a different object, ie not a parent/child + # relationship as this will break the cache and other object + if not original._meta.get_path_to_parent(from_obj._meta.model): + self.local_setter(from_obj, obj) + if obj is not None: + self.remote_setter(obj, from_obj) + return None, None + + local_pk_name = original.__class__.polymorphic_primary_key_name + target_pk_name = original.__class__.polymorphic_primary_key_name + original_pk = getattr(original, local_pk_name) + + # NOTE: We could use a recursive function on model_target_cls._meta.parents + # PolymorphicModel.much _get_inheritance_relation_fields_and_models.like add_all_sub_models + for field in model_target_cls._meta.fields: + if field.is_relation is True: + for rel_field in field.foreign_related_fields: + if rel_field.name is local_pk_name and rel_field.model is original._meta.model: + target_pk_name = field.attname + + return model_target_cls, (original_pk, self.field.name, target_pk_name) + + +def get_related_populators(klass_info, select, db): + from .models import PolymorphicModel + + iterators = [] + related_klass_infos = klass_info.get("related_klass_infos", []) + for rel_klass_info in related_klass_infos: + model = rel_klass_info["model"] + rel_cls = VanillaRelatedPopulator(rel_klass_info, select, db) + if issubclass(model, PolymorphicModel): + rel_cls = RelatedPolymorphicPopulator(rel_klass_info, select, db) + else: + for col, *_ in select: + if issubclass(col.target.model, PolymorphicModel): + rel_cls = RelatedPolymorphicPopulator(rel_klass_info, select, db) + break + iterators.append(rel_cls) + return iterators + + +class PolymorphicModelIterable(ModelIterable): + """ + ModelIterable for PolymorphicModel + + Yields real instances if qs.polymorphic_disabled is False, + otherwise acts like a regular ModelIterable. We inherit from + ModelIterable non base BaseIterable even though we completely + replace it, but this allows Django test in Prefetch to work + """ + + def __iter__(self): + queryset = self.queryset + db = queryset.db + compiler = queryset.query.get_compiler(using=db) + # Execute the query. This will also fill compiler.select, klass_info, + # and annotations. + results = compiler.execute_sql( + chunked_fetch=self.chunked_fetch, chunk_size=self.chunk_size + ) + select, klass_info, annotation_col_map = ( + compiler.select, + compiler.klass_info, + compiler.annotation_col_map, + ) # some databases have a limit on the number of query parameters, we must # respect this for generating get_real_instances queries because those # queries do a large WHERE IN clause with primary keys @@ -64,24 +312,161 @@ def _polymorphic_iterator(self, base_iter): sql_chunk = sql_chunk or Polymorphic_QuerySet_objects_per_request + model_cls = klass_info["model"] + select_fields = klass_info["select_fields"] + model_fields_start, model_fields_end = select_fields[0], select_fields[-1] + 1 + init_list = [f[0].target.attname for f in select[model_fields_start:model_fields_end]] + related_populators = get_related_populators(klass_info, select, db) + known_related_objects = [ + ( + field, + related_objs, + operator.attrgetter( + *[ + ( + field.attname + if from_field == "self" + else queryset.model._meta.get_field(from_field).attname + ) + for from_field in field.from_fields + ] + ), + ) + for field, related_objs in queryset._known_related_objects.items() + ] + base_iter = compiler.results_iter(results) while True: + result_objects = [] base_result_objects = [] reached_end = False # Fetch in chunks - for _ in range(sql_chunk): + post_actions = list() + for i in range(sql_chunk): + # dict contains one entry per unique model type occurring in result, + # in the format idlist_per_model[modelclass]=[list-of-object-ids] try: - o = next(base_iter) - base_result_objects.append(o) + row = next(base_iter) + obj = model_cls.from_db( + db, init_list, row[model_fields_start:model_fields_end] + ) + for rel_populator in related_populators: + rel_populator.build_related(row, obj, post_actions) + base_result_objects.append([row, obj]) except StopIteration: reached_end = True break - yield from self.queryset._get_real_instances(base_result_objects) + if not self.queryset.polymorphic_disabled: + self.fetch_polymorphic(post_actions, base_result_objects) + + for row, obj in base_result_objects: + if annotation_col_map: + for attr_name, col_pos in annotation_col_map.items(): + setattr(obj, attr_name, row[col_pos]) + + # Add the known related objects to the model. + for field, rel_objs, rel_getter in known_related_objects: + # Avoid overwriting objects loaded by, e.g., select_related(). + if field.is_cached(obj): + continue + rel_obj_id = rel_getter(obj) + try: + rel_obj = rel_objs[rel_obj_id] + except KeyError: + pass # May happen in qs1 | qs2 scenarios. + else: + setattr(obj, field.name, rel_obj) + result_objects.append(obj) + + if not self.queryset.polymorphic_disabled: + result_objects = self.queryset._get_real_instances(result_objects) + + for o in result_objects: + yield o if reached_end: return + def apply_select_related(self, qs, relations): + if self.queryset.query.select_related is True: + return qs.select_related() + + model_name = qs.model.__name__.lower() + if isinstance(self.queryset.query.select_related, dict): + select_related = {} + if isinstance(qs.query.select_related, dict): + select_related = qs.query.select_related + for k, v in self.queryset.query.select_related.items(): + if k in relations: + if not isinstance(select_related, dict): + select_related = {} + if isinstance(v, dict): + if model_name in v: + select_related = merge_dicts(select_related, v[model_name]) + else: + for field in qs.model._meta.fields: + if field.name in v: + select_related = merge_dicts(select_related, v[field.name]) + else: + select_related = merge_dicts(select_related, v) + qs.query.select_related = select_related + return qs + + def fetch_polymorphic(self, post_actions, base_result_objects): + update_fn_per_model = defaultdict(list) + idlist_per_model = defaultdict(list) + + for action, populate_fn in post_actions: + target_class, pk_info = action() + if target_class: + pk, name, pk_name = pk_info + idlist_per_model[target_class].append(pk_info) + update_fn_per_model[target_class].append((populate_fn, pk)) + + # For each model in "idlist_per_model" request its objects (the real model) + # from the db and store them in results[]. + # Then we copy the annotate fields from the base objects to the real objects. + # Then we copy the extra() select fields from the base objects to the real objects. + # TODO: defer(), only(): support for these would be around here + for real_concrete_class, data in idlist_per_model.items(): + idlist, names, pk_attr_names = zip(*data) + updates = update_fn_per_model[real_concrete_class] + + if len(set(pk_attr_names)) != 1: + raise FieldError( + "PolymorphicModel: cannot convert model type as non " + f"upk_namesnique related key names {pk_attr_names}" + ) + + pk_attr_name = pk_attr_names[0] + # FIXME: this seams to get extra field already fetch in base + # initial query, we may need to add defer? + + real_objects = real_concrete_class._base_objects.db_manager(self.queryset.db).filter( + **{("%s__in" % pk_attr_name): idlist}, + ) + + real_objects = self.apply_select_related(real_objects, set(names)) + real_objects_dict = { + getattr(real_object, pk_attr_name): real_object for real_object in real_objects + } + + for populate_fn, o_pk in updates: + real_object = real_objects_dict.get(o_pk) + if real_object is None: + continue + + # need shallow copy to avoid duplication in caches (see PR #353) + real_object = copy.copy(real_object) + real_class = real_object.get_real_instance_class() + + # If the real class is a proxy, upcast it + if real_class != real_concrete_class: + real_object = transmogrify(real_class, real_object) + + populate_fn(real_object) + def transmogrify(cls, obj): """ @@ -103,7 +488,75 @@ def transmogrify(cls, obj): # PolymorphicQuerySet -class PolymorphicQuerySet(QuerySet): +class PolymorphicQuerySetMixin(QuerySet): + def select_related(self, *fields): + if fields == (None,) or not len(fields): + return super().select_related(*fields) + field_with_poly = set(self.convert_related_fieldnames(fields)) + return super().select_related(*sorted(list(field_with_poly))) + + def _convert_field_name_part(self, field_parts, model): + """ + recursively convert a fieldname into (model, filedname) + """ + field = None + part = field_parts[0] + next_parts = field_parts[1:] + field_path = [] + rel_model = None + try: + field = model._meta.get_field(part) + field_path = [part] + yield field_path + + if field.is_relation: + rel_model = field.related_model + if next_parts: + child_selectors = self._convert_field_name_part(next_parts, rel_model) + for selector in child_selectors: + yield field_path + selector + else: + rel_model = model + except FieldDoesNotExist: + submodels = _get_all_sub_models(model) + if part == "*": + for rel_model in submodels.values(): + if model is rel_model: + continue + yield from self._convert_submodel_fields_parts(next_parts, model, rel_model) + else: + rel_model = submodels.get(part, None) + if model is not rel_model: + yield from self._convert_submodel_fields_parts(next_parts, model, rel_model) + else: + raise + + def _convert_submodel_fields_parts(self, field_parts, model, rel_model): + field_path = list(_create_base_path(model, rel_model).split("__")) + for field_part_idx in range(0, len(field_path)): + yield field_path[0 : 1 + field_part_idx] + yield field_path + if field_parts: + child_selectors = self._convert_field_name_part(field_parts, rel_model) + for selector in child_selectors: + yield field_path + selector + + def convert_related_fieldnames(self, fields, opts=None): + """ + convert the field name which may contain polymorphic models names into + raw filed names that can be used with django select_related and + prefetch_related. + """ + if not opts: + opts = self.model + for field_name in fields: + field_parts = field_name.split(LOOKUP_SEP) + selectors = self._convert_field_name_part(field_parts, opts) + for selector in selectors: + yield "__".join(selector) + + +class PolymorphicQuerySet(PolymorphicQuerySetMixin, QuerySet): """ QuerySet for PolymorphicModel @@ -191,9 +644,9 @@ def _filter_or_exclude(self, negate, args, kwargs): def order_by(self, *field_names): """translate the field paths in the args, then call vanilla order_by.""" field_names = [ - translate_polymorphic_field_path(self.model, a) - if isinstance(a, str) - else a # allow expressions to pass unchanged + ( + translate_polymorphic_field_path(self.model, a) if isinstance(a, str) else a + ) # allow expressions to pass unchanged for a in field_names ] return super().order_by(*field_names) @@ -291,7 +744,7 @@ def tree_node_test___lookup(my_model, node): for i in range(len(node.children)): child = node.children[i] - if type(child) is tuple: + if isinstance(child, tuple): # this Q object child is a tuple => a kwarg like Q( instance_of=ModelB ) assert "___" not in child[0], ___lookup_assert_msg else: @@ -412,22 +865,37 @@ class self.model, but as a class derived from self.model. We want to re-fetch real_concrete_class = content_type_manager.get_for_id( real_concrete_class_id ).model_class() - idlist_per_model[real_concrete_class].append(getattr(base_object, pk_name)) - indexlist_per_model[real_concrete_class].append((i, len(resultlist))) - resultlist.append(None) + + cached_obj = search_object_cache(base_object, self.model, real_concrete_class) + if cached_obj: + resultlist.append(cached_obj) + else: + idlist_per_model[real_concrete_class].append(getattr(base_object, pk_name)) + indexlist_per_model[real_concrete_class].append((i, len(resultlist))) + resultlist.append(None) # For each model in "idlist_per_model" request its objects (the real model) # from the db and store them in results[]. # Then we copy the annotate fields from the base objects to the real objects. # Then we copy the extra() select fields from the base objects to the real objects. # TODO: defer(), only(): support for these would be around here + # Also see PolymorphicModelIterable.fetch_polymorphic + + filter_relations = [ + _get_query_related_name(mdl_cls) + for mdl_cls in _get_all_sub_models(self.model).values() + ] + for real_concrete_class, idlist in idlist_per_model.items(): indices = indexlist_per_model[real_concrete_class] real_objects = real_concrete_class._base_objects.db_manager(self.db).filter( **{(f"{pk_name}__in"): idlist} ) # copy select related configuration to new qs - real_objects.query.select_related = self.query.select_related + current_relation = real_objects.model.__name__.lower() + real_objects = self.apply_select_related( + real_objects, current_relation, filter_relations + ) # Copy deferred fields configuration to the new queryset deferred_loading_fields = [] @@ -516,6 +984,37 @@ class self.model, but as a class derived from self.model. We want to re-fetch return resultlist + def apply_select_related(self, qs, relation, filtered): + if self.query.select_related is True: + return qs.select_related() + + model_name = qs.model.__name__.lower() + if isinstance(self.query.select_related, dict): + select_related = {} + if isinstance(qs.query.select_related, dict): + select_related = qs.query.select_related + for k, v in self.query.select_related.items(): + if k in filtered and k != relation: + continue + else: + if not isinstance(select_related, dict): + select_related = {} + if k == relation: + if isinstance(v, dict): + if model_name in v: + select_related = merge_dicts(select_related, v[model_name]) + else: + for field in qs.model._meta.fields: + if field.name in v: + select_related = merge_dicts(select_related, v[field.name]) + else: + select_related = merge_dicts(select_related, v) + else: + select_related[k] = v + + qs.query.select_related = select_related + return qs + def __repr__(self, *args, **kwargs): if self.model.polymorphic_query_multiline_output: result = ",\n ".join(repr(o) for o in self.all()) @@ -557,3 +1056,60 @@ def delete(self): disrupts the model hierarchy/relationship traversal. """ return QuerySet.delete(self.non_polymorphic()) + + +################################################################################### +# PolymorphicRelatedQuerySet + + +class PolymorphicRelatedQuerySetMixin(PolymorphicQuerySetMixin): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._iterable_class = PolymorphicModelIterable + self.polymorphic_disabled = False + + def _clone(self, *args, **kwargs): + # Django's _clone only copies its own variables, so we need to copy ours here + new = super()._clone(*args, **kwargs) + new.polymorphic_disabled = self.polymorphic_disabled + return new + + def _get_real_instances(self, base_result_objects): + return base_result_objects + + +class PolymorphicRelatedQuerySet(PolymorphicRelatedQuerySetMixin, QuerySet): + pass + + +def convert_to_polymorphic_queryset(qs): + "Convert a queryset to one that support polymorphic evaluation" + + if isinstance(qs, Manager): + qs = qs.get_queryset() + + if issubclass(qs.__class__, PolymorphicQuerySetMixin): + return qs + + assert issubclass(QuerySet, qs.__class__), ( + f"PolymorphicModel: cannot guarantee conversion of {qs.__class__} to polymorphic queryset" + ) + + class RelatedPolyQuerySet(PolymorphicRelatedQuerySetMixin, qs.__class__): + @classmethod + def _convert_to(cls, qs): + c = cls( + model=qs.model, + query=qs.query.chain(), + using=qs._db, + hints=qs._hints, + ) + c._sticky_filter = qs._sticky_filter + c._for_write = qs._for_write + c._prefetch_related_lookups = qs._prefetch_related_lookups[:] + c._known_related_objects = qs._known_related_objects + c._fields = qs._fields + return c + + poly_qs = RelatedPolyQuerySet._convert_to(qs) + return poly_qs diff --git a/src/polymorphic/showfields.py b/src/polymorphic/showfields.py index e894c70b..082caf6b 100644 --- a/src/polymorphic/showfields.py +++ b/src/polymorphic/showfields.py @@ -62,7 +62,7 @@ def _showfields_add_regular_fields(self, parts): out = field.name # if this is the standard primary key named "id", print it as we did with older versions of django_polymorphic - if field.primary_key and field.name == "id" and type(field) is models.AutoField: + if field.primary_key and field.name == "id" and isinstance(field, models.AutoField): out += f" {getattr(self, field.name)}" # otherwise, display it just like all other fields (with correct type, shortened content etc.) diff --git a/src/polymorphic/tests/migrations/0001_initial.py b/src/polymorphic/tests/migrations/0001_initial.py index 99c9ef8e..08d4f187 100644 --- a/src/polymorphic/tests/migrations/0001_initial.py +++ b/src/polymorphic/tests/migrations/0001_initial.py @@ -1,4 +1,4 @@ -# Generated by Django 4.2 on 2025-12-21 10:54 +# Generated by Django 4.2 on 2025-12-23 08:21 from django.conf import settings from django.db import migrations, models @@ -267,6 +267,20 @@ class Migration(migrations.Migration): 'base_manager_name': 'objects', }, ), + migrations.CreateModel( + name='NonSymRelationBase', + fields=[ + ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('field_base', models.CharField(max_length=10)), + ('fk', models.ForeignKey(null=True, on_delete=django.db.models.deletion.CASCADE, related_name='relationbase_set', to='tests.nonsymrelationbase')), + ('m2m', models.ManyToManyField(to='tests.nonsymrelationbase')), + ('polymorphic_ctype', models.ForeignKey(editable=False, null=True, on_delete=django.db.models.deletion.CASCADE, related_name='polymorphic_%(app_label)s.%(class)s_set+', to='contenttypes.contenttype')), + ], + options={ + 'abstract': False, + 'base_manager_name': 'objects', + }, + ), migrations.CreateModel( name='NormalBase', fields=[ @@ -287,6 +301,18 @@ class Migration(migrations.Migration): 'base_manager_name': 'objects', }, ), + migrations.CreateModel( + name='ParentModel', + fields=[ + ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('name', models.CharField(max_length=10)), + ('polymorphic_ctype', models.ForeignKey(editable=False, null=True, on_delete=django.db.models.deletion.CASCADE, related_name='polymorphic_%(app_label)s.%(class)s_set+', to='contenttypes.contenttype')), + ], + options={ + 'abstract': False, + 'base_manager_name': 'objects', + }, + ), migrations.CreateModel( name='Participant', fields=[ @@ -305,6 +331,13 @@ class Migration(migrations.Migration): ('field1', models.CharField(max_length=30)), ], ), + migrations.CreateModel( + name='PlainModel', + fields=[ + ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('relation', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='tests.parentmodel')), + ], + ), migrations.CreateModel( name='PlainParentModelWithManager', fields=[ @@ -387,6 +420,19 @@ class Migration(migrations.Migration): }, bases=(polymorphic.showfields.ShowFieldTypeAndContent, models.Model), ), + migrations.CreateModel( + name='AltChildModel', + fields=[ + ('parentmodel_ptr', models.OneToOneField(auto_created=True, on_delete=django.db.models.deletion.CASCADE, parent_link=True, primary_key=True, serialize=False, to='tests.parentmodel')), + ('other_name', models.CharField(max_length=10)), + ('link_on_altchild', models.ForeignKey(null=True, on_delete=django.db.models.deletion.CASCADE, related_name='+', to='tests.plaina')), + ], + options={ + 'abstract': False, + 'base_manager_name': 'objects', + }, + bases=('tests.parentmodel',), + ), migrations.CreateModel( name='BlogA', fields=[ @@ -710,6 +756,42 @@ class Migration(migrations.Migration): }, bases=('tests.proxybase',), ), + migrations.CreateModel( + name='NonSymRelationA', + fields=[ + ('nonsymrelationbase_ptr', models.OneToOneField(auto_created=True, on_delete=django.db.models.deletion.CASCADE, parent_link=True, primary_key=True, serialize=False, to='tests.nonsymrelationbase')), + ('field_a', models.CharField(max_length=10)), + ], + options={ + 'abstract': False, + 'base_manager_name': 'objects', + }, + bases=('tests.nonsymrelationbase',), + ), + migrations.CreateModel( + name='NonSymRelationB', + fields=[ + ('nonsymrelationbase_ptr', models.OneToOneField(auto_created=True, on_delete=django.db.models.deletion.CASCADE, parent_link=True, primary_key=True, serialize=False, to='tests.nonsymrelationbase')), + ('field_b', models.CharField(max_length=10)), + ], + options={ + 'abstract': False, + 'base_manager_name': 'objects', + }, + bases=('tests.nonsymrelationbase',), + ), + migrations.CreateModel( + name='NonSymRelationBC', + fields=[ + ('nonsymrelationbase_ptr', models.OneToOneField(auto_created=True, on_delete=django.db.models.deletion.CASCADE, parent_link=True, primary_key=True, serialize=False, to='tests.nonsymrelationbase')), + ('field_c', models.CharField(max_length=10)), + ], + options={ + 'abstract': False, + 'base_manager_name': 'objects', + }, + bases=('tests.nonsymrelationbase',), + ), migrations.CreateModel( name='NormalExtension', fields=[ @@ -851,6 +933,13 @@ class Migration(migrations.Migration): }, bases=('tests.uuidproject',), ), + migrations.CreateModel( + name='VanillaPlainModel', + fields=[ + ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('relation', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='tests.parentmodel')), + ], + ), migrations.CreateModel( name='SwappedModel', fields=[ @@ -893,6 +982,13 @@ class Migration(migrations.Migration): }, bases=(polymorphic.showfields.ShowFieldType, models.Model), ), + migrations.CreateModel( + name='RefPlainModel', + fields=[ + ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('plainobj', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='tests.plainmodel')), + ], + ), migrations.CreateModel( name='RecursionBug', fields=[ @@ -927,6 +1023,14 @@ class Migration(migrations.Migration): }, bases=(polymorphic.showfields.ShowFieldTypeAndContent, models.Model), ), + migrations.CreateModel( + name='PlainModelWithM2M', + fields=[ + ('id', models.AutoField(auto_created=True, primary_key=True, serialize=False, verbose_name='ID')), + ('field1', models.CharField(max_length=10)), + ('m2m', models.ManyToManyField(to='tests.parentmodel')), + ], + ), migrations.CreateModel( name='PlainChildModelWithManager', fields=[ @@ -1269,6 +1373,18 @@ class Migration(migrations.Migration): }, bases=('tests.subclassselectorproxybasemodel',), ), + migrations.CreateModel( + name='AltChildAsBaseModel', + fields=[ + ('altchildmodel_ptr', models.OneToOneField(auto_created=True, on_delete=django.db.models.deletion.CASCADE, parent_link=True, primary_key=True, serialize=False, to='tests.altchildmodel')), + ('more_name', models.CharField(max_length=10)), + ], + options={ + 'abstract': False, + 'base_manager_name': 'objects', + }, + bases=('tests.altchildmodel',), + ), migrations.CreateModel( name='Bottom', fields=[ @@ -1317,6 +1433,30 @@ class Migration(migrations.Migration): }, bases=('tests.mrobase2', 'tests.mrobase3'), ), + migrations.CreateModel( + name='NonAutoPKChild', + fields=[ + ('altchildmodel_ptr', models.OneToOneField(auto_created=True, on_delete=django.db.models.deletion.CASCADE, parent_link=True, to='tests.altchildmodel')), + ('uuid_primary_key', models.UUIDField(default=uuid.uuid1, primary_key=True, serialize=False)), + ], + options={ + 'abstract': False, + 'base_manager_name': 'objects', + }, + bases=('tests.altchildmodel',), + ), + migrations.CreateModel( + name='NonUUIDArtProject', + fields=[ + ('uuidresearchproject_ptr', models.OneToOneField(auto_created=True, on_delete=django.db.models.deletion.CASCADE, parent_link=True, to='tests.uuidresearchproject')), + ('idkey', models.AutoField(primary_key=True, serialize=False)), + ], + options={ + 'abstract': False, + 'base_manager_name': 'objects', + }, + bases=('tests.uuidresearchproject',), + ), migrations.CreateModel( name='PlainC', fields=[ @@ -1437,6 +1577,19 @@ class Migration(migrations.Migration): }, bases=('tests.inlinemodela',), ), + migrations.CreateModel( + name='ChildModel', + fields=[ + ('parentmodel_ptr', models.OneToOneField(auto_created=True, on_delete=django.db.models.deletion.CASCADE, parent_link=True, primary_key=True, serialize=False, to='tests.parentmodel')), + ('other_name', models.CharField(max_length=10)), + ('link_on_child', models.ForeignKey(null=True, on_delete=django.db.models.deletion.CASCADE, related_name='+', to='tests.modelextraexternal')), + ], + options={ + 'abstract': False, + 'base_manager_name': 'objects', + }, + bases=('tests.parentmodel',), + ), migrations.CreateModel( name='BlogEntry', fields=[ @@ -1497,6 +1650,18 @@ class Migration(migrations.Migration): }, bases=('tests.uuidartprojecta',), ), + migrations.CreateModel( + name='AltChildWithM2MModel', + fields=[ + ('altchildmodel_ptr', models.OneToOneField(auto_created=True, on_delete=django.db.models.deletion.CASCADE, parent_link=True, primary_key=True, serialize=False, to='tests.altchildmodel')), + ('m2m', models.ManyToManyField(to='tests.plaina')), + ], + options={ + 'abstract': False, + 'base_manager_name': 'objects', + }, + bases=('tests.altchildmodel',), + ), migrations.CreateModel( name='UUIDArtProjectC', fields=[ diff --git a/src/polymorphic/tests/models.py b/src/polymorphic/tests/models.py index 1bb10380..9c43e458 100644 --- a/src/polymorphic/tests/models.py +++ b/src/polymorphic/tests/models.py @@ -10,7 +10,11 @@ from polymorphic.managers import PolymorphicManager from polymorphic.models import PolymorphicModel -from polymorphic.query import PolymorphicQuerySet +from polymorphic.query import ( + PolymorphicQuerySet, + PolymorphicRelatedQuerySetMixin, + PolymorphicRelatedQuerySet, +) from polymorphic.showfields import ShowFieldContent, ShowFieldType, ShowFieldTypeAndContent @@ -342,6 +346,10 @@ class UUIDResearchProject(UUIDProject): supervisor = models.CharField(max_length=30) +class NonUUIDArtProject(UUIDResearchProject): + idkey = models.AutoField(primary_key=True) + + class UUIDArtProjectA(UUIDArtProject): ... @@ -781,3 +789,78 @@ def __init__(self, *args, **kwargs): """ super().__init__(*args, **kwargs) self.old_status_id = self.status_id + + +class NonSymRelationBase(PolymorphicModel): + field_base = models.CharField(max_length=10) + fk = models.ForeignKey( + "self", on_delete=models.CASCADE, null=True, related_name="relationbase_set" + ) + m2m = models.ManyToManyField("self", symmetrical=False) + + +class NonSymRelationA(NonSymRelationBase): + field_a = models.CharField(max_length=10) + + +class NonSymRelationB(NonSymRelationBase): + field_b = models.CharField(max_length=10) + + +class NonSymRelationBC(NonSymRelationBase): + field_c = models.CharField(max_length=10) + + +class CustomPolySupportingQuerySet(PolymorphicRelatedQuerySetMixin, models.QuerySet): + pass + + +class ParentModel(PolymorphicModel): + name = models.CharField(max_length=10) + + +class ChildModel(ParentModel): + other_name = models.CharField(max_length=10) + link_on_child = models.ForeignKey( + ModelExtraExternal, on_delete=models.CASCADE, null=True, related_name="+" + ) + + +class AltChildModel(ParentModel): + other_name = models.CharField(max_length=10) + link_on_altchild = models.ForeignKey( + PlainA, on_delete=models.CASCADE, null=True, related_name="+" + ) + + +class AltChildAsBaseModel(AltChildModel): + more_name = models.CharField(max_length=10) + + +class NonAutoPKChild(AltChildModel): + uuid_primary_key = models.UUIDField(primary_key=True, default=uuid.uuid1) + + +class PlainModel(models.Model): + relation = models.ForeignKey(ParentModel, on_delete=models.CASCADE) + objects = models.Manager.from_queryset(PolymorphicRelatedQuerySet)() + + +class VanillaPlainModel(models.Model): + relation = models.ForeignKey(ParentModel, on_delete=models.CASCADE) + + +class RefPlainModel(models.Model): + plainobj = models.ForeignKey(PlainModel, on_delete=models.CASCADE) + objects = models.Manager.from_queryset(QuerySet)() + poly_objects = models.Manager.from_queryset(PolymorphicRelatedQuerySet)() + + +class PlainModelWithM2M(models.Model): + field1 = models.CharField(max_length=10) + m2m = models.ManyToManyField(ParentModel) + objects = models.Manager.from_queryset(PolymorphicRelatedQuerySet)() + + +class AltChildWithM2MModel(AltChildModel): + m2m = models.ManyToManyField(PlainA) diff --git a/src/polymorphic/tests/test_orm.py b/src/polymorphic/tests/test_orm.py index 412422cd..cb1735fb 100644 --- a/src/polymorphic/tests/test_orm.py +++ b/src/polymorphic/tests/test_orm.py @@ -1,18 +1,33 @@ import pytest import uuid - +from unittest import expectedFailure from django.contrib.auth import get_user_model from django.contrib.contenttypes.models import ContentType from django.db import models, connection -from django.db.models import Case, Count, FilteredRelation, Q, Sum, When, Exists, OuterRef +from django.db.models import ( + Case, + Count, + FilteredRelation, + Prefetch, + Q, + Sum, + When, + Exists, + OuterRef, +) from django.db.utils import IntegrityError, NotSupportedError + from django.test import TransactionTestCase from django.test.utils import CaptureQueriesContext from polymorphic import query_translate +from polymorphic.query import convert_to_polymorphic_queryset, PolymorphicRelatedQuerySetMixin from polymorphic.managers import PolymorphicManager from polymorphic.models import PolymorphicTypeInvalid, PolymorphicTypeUndefined from polymorphic.tests.models import ( + AltChildAsBaseModel, + AltChildModel, + AltChildWithM2MModel, ArtProject, Base, BlogA, @@ -20,9 +35,11 @@ BlogBase, BlogEntry, BlogEntry_limit_choices_to, + ChildModel, ChildModelWithManager, CustomPkBase, CustomPkInherit, + Duck, Enhance_Base, Enhance_Plain, Enhance_Inherit, @@ -57,15 +74,24 @@ MultiTableDerived, MyManager, MyManagerQuerySet, + NonAutoPKChild, NonPolymorphicParent, NonProxyChild, + NonSymRelationA, + NonSymRelationB, + NonSymRelationBase, + NonSymRelationBC, + NonUUIDArtProject, One2OneRelatingModel, One2OneRelatingModelDerived, + ParentModel, ParentModelWithManager, PlainA, PlainB, PlainC, PlainChildModelWithManager, + PlainModel, + PlainModelWithM2M, PlainMyManager, PlainMyManagerQuerySet, PlainParentModelWithManager, @@ -76,6 +102,7 @@ ProxyModelB, ProxyModelBase, RedheadDuck, + RefPlainModel, RelatingModel, RelationA, RelationB, @@ -97,7 +124,7 @@ UUIDPlainC, UUIDProject, UUIDResearchProject, - Duck, + VanillaPlainModel, PurpleHeadDuck, Account, SpecialAccount1, @@ -998,21 +1025,21 @@ def test_polymorphic__accessor_caching(self): blog_a = BlogA.objects.get(id=blog_a.id) # test reverse accessor & check that we get back cached object on repeated access - self.assertEqual(blog_base.bloga, blog_a) - self.assertIs(blog_base.bloga, blog_base.bloga) + assert blog_base.bloga == blog_a + assert blog_base.bloga is blog_base.bloga cached_blog_a = blog_base.bloga # test forward accessor & check that we get back cached object on repeated access - self.assertEqual(blog_a.blogbase_ptr, blog_base) - self.assertIs(blog_a.blogbase_ptr, blog_a.blogbase_ptr) + assert blog_a.blogbase_ptr == blog_base + assert blog_a.blogbase_ptr is blog_a.blogbase_ptr cached_blog_base = blog_a.blogbase_ptr # check that refresh_from_db correctly clears cached related objects blog_base.refresh_from_db() blog_a.refresh_from_db() - self.assertIsNot(cached_blog_a, blog_base.bloga) - self.assertIsNot(cached_blog_base, blog_a.blogbase_ptr) + assert cached_blog_a is not blog_base.bloga + assert cached_blog_base is not blog_a.blogbase_ptr def test_polymorphic__aggregate(self): """test ModelX___field syntax on aggregate (should work for annotate either)""" @@ -1967,3 +1994,875 @@ def test_infinite_recursion_with_only(self): RecursionBug.objects.filter(id=item.id).update(status=closed) item.refresh_from_db(fields=("status",)) assert item.status == closed + + def test_normal_django_to_poly_related_give_poly_type(self): + obj1 = ParentModel.objects.create(name="m1") + obj2 = ChildModel.objects.create(name="m2", other_name="m2") + obj3 = ChildModel.objects.create(name="m3") + obj4 = ChildModel.objects.create(name="m3") + obj5 = AltChildModel.objects.create(name="m4") + + PlainModel.objects.create(relation=obj1) + PlainModel.objects.create(relation=obj2) + PlainModel.objects.create(relation=obj3) + PlainModel.objects.create(relation=obj4) + PlainModel.objects.create(relation=obj5) + + with self.assertNumQueries(10): + # Queries will be + # * 1 for All PlainModels object (1) + # * 1 for each relations ParentModel (5) + # * 1 for each relations ChildModel is needed (3) + # * 1 for each relations AltChildModel is needed (1) + multi_q = [ + # these obj.relation values will have their proper sub type + obj.relation + for obj in PlainModel.objects.all() + ] + multi_q_types = [type(obj) for obj in multi_q] + + with self.assertNumQueries(3): + grouped_q = [ + # these obj.relation values will all be ParentModel's + # unless we fix select related but should be their proper + # sub type by using PolymorphicRelatedQuerySetMixin + # 1 query for each relation type + obj.relation + for obj in PlainModel.objects.select_related("relation") + ] + grouped_q_types = [type(obj) for obj in grouped_q] + + self.assertListEqual(multi_q_types, grouped_q_types) + self.assertListEqual(grouped_q, [obj1, obj2, obj3, obj4, obj5]) + + def test_normal_django_to_multi_level_poly_related_give_poly_type(self): + obj1 = ParentModel.objects.create(name="m1") + obj2 = ChildModel.objects.create(name="m2", other_name="c1") + obj3 = AltChildModel.objects.create(name="m3") + obj4 = AltChildAsBaseModel.objects.create(name="m4", more_name="acab1") + + PlainModel.objects.create(relation=obj1) + PlainModel.objects.create(relation=obj2) + PlainModel.objects.create(relation=obj3) + PlainModel.objects.create(relation=obj4) + + with self.assertNumQueries(4): + grouped_q = [ + # these obj.relation values will all be ParentModel's + # unless we fix select related but should be their proper + # sub type by using PolymorphicRelatedQuerySetMixin + obj.relation + for obj in PlainModel.objects.select_related("relation") + ] + self.assertListEqual(grouped_q, [obj1, obj2, obj3, obj4]) + + def test_related_fetch_of_different_type_pks(self): + "pk on child is not same field type as pk on parent and thus prt field" + obj1 = ChildModel.objects.create(name="m1", other_name="c1") + obj2 = ParentModel.objects.create(name="m2") + obj3 = NonAutoPKChild.objects.create(name="m3", other_name="napk1") + obj4 = AltChildModel.objects.create(name="m4", other_name="acm1") + obj5 = ChildModel.objects.create(name="m5", other_name="c3") + obj6 = AltChildAsBaseModel.objects.create(name="m6", more_name="acab1") + obj7 = NonAutoPKChild.objects.create(name="m7", other_name="napk2") + + PlainModel.objects.create(relation=obj1) + PlainModel.objects.create(relation=obj2) + PlainModel.objects.create(relation=obj3) + PlainModel.objects.create(relation=obj4) + PlainModel.objects.create(relation=obj5) + PlainModel.objects.create(relation=obj6) + PlainModel.objects.create(relation=obj7) + + def object_info(obj): + return { + "pk": obj.pk, + "parentmodel_ptr": getattr(obj, "parentmodel_ptr_id", None), + "altchildmodel_ptr": getattr(obj, "altchildmodel_ptr_id", None), + } + + with self.assertNumQueries(5): + grouped_q = [ + # these obj.relation values will all be ParentModel's + # unless we fix select related but should be their proper + # sub type by using PolymorphicRelatedQuerySetMixin + obj.relation + for obj in PlainModel.objects.select_related("relation").order_by("pk") + ] + grouped_info = [object_info(obj) for obj in grouped_q] + self.assertListEqual( + grouped_info, [object_info(obj) for obj in [obj1, obj2, obj3, obj4, obj5, obj6, obj7]] + ) + self.assertListEqual(grouped_q, [obj1, obj2, obj3, obj4, obj5, obj6, obj7]) + + def test_related_fetch_of_non_sequential_pks(self): + obj1 = ChildModel.objects.create(name="m1", other_name="c1") + obj2 = ParentModel.objects.create(name="m2") + + # FIXME use PK from table to get in sequential PKS + # from django.db import connection + # with connection.cursor() as cursor: + # cursor.execute('INSERT INTO "tests_childmodel" ("other_name", "parentmodel_ptr_id") VALUES (%s, %s)', ['fake', 1]) + + obj3 = ChildModel.objects.create(name="m3", other_name="c2") + obj4 = AltChildModel.objects.create(name="m4", other_name="acm1") + obj5 = ChildModel.objects.create(name="m5", other_name="c3") + obj6 = AltChildAsBaseModel.objects.create(name="m6", more_name="acab1") + + PlainModel.objects.create(relation=obj1) + PlainModel.objects.create(relation=obj2) + PlainModel.objects.create(relation=obj3) + PlainModel.objects.create(relation=obj4) + PlainModel.objects.create(relation=obj5) + PlainModel.objects.create(relation=obj6) + + def object_info(obj): + return { + "pk": obj.pk, + "parentmodel_ptr": getattr(obj, "parentmodel_ptr_id", None), + "altchildmodel_ptr": getattr(obj, "altchildmodel_ptr_id", None), + } + + with self.assertNumQueries(4): + grouped_q = [ + # these obj.relation values will all be ParentModel's + # unless we fix select related but should be their proper + # sub type by using PolymorphicRelatedQuerySetMixin + obj.relation + for obj in PlainModel.objects.select_related("relation") + ] + grouped_info = [object_info(obj) for obj in grouped_q] + self.assertListEqual( + grouped_info, [object_info(obj) for obj in [obj1, obj2, obj3, obj4, obj5, obj6]] + ) + self.assertListEqual(grouped_q, [obj1, obj2, obj3, obj4, obj5, obj6]) + + def test_normal_django_to_poly_related_give_poly_type_using_select_related_true(self): + obj1 = ParentModel.objects.create(name="m1") + obj2 = ChildModel.objects.create(name="m2", other_name="m2") + obj3 = ChildModel.objects.create(name="m1") + obj4 = AltChildAsBaseModel.objects.create( + name="ac2", other_name="ac2name", more_name="ac2_mn" + ) + + PlainModel.objects.create(relation=obj1) + PlainModel.objects.create(relation=obj2) + PlainModel.objects.create(relation=obj3) + PlainModel.objects.create(relation=obj4) + + with self.assertNumQueries(8): + # Queries will be + # * 1 for All PlainModels object (x1) + # * 1 for each relations ParentModel (x4) + # * 1 for each relations ChildModel is needed (x2) + # * 1 for each relations AltChildAsBaseModel is needed (x1) + multi_q = [ + # these obj.relation values will have their proper sub type + obj.relation + for obj in PlainModel.objects.all() + ] + multi_q_types = [type(obj) for obj in multi_q] + + with self.assertNumQueries(3): + grouped_q = [ + # these obj.relation values will all be ParentModel's + # unless we fix select related but should be their proper + # sub type by using PolymorphicRelatedQuerySetMixin + # ATM: we require 1 query fro each type. Although this can + # be reduced by specifying the relations to the polymorphic + # classes. BUT this has the downside of making the query have + # a large number of joins + obj.relation + for obj in PlainModel.objects.select_related() + ] + grouped_q_types = [type(obj) for obj in grouped_q] + + self.assertListEqual(multi_q_types, grouped_q_types) + self.assertListEqual(grouped_q, [obj1, obj2, obj3, obj4]) + + def test_prefetch_base_load_359(self): + obj1_1 = ModelShow1_plain.objects.create(field1="1") + obj2_1 = ModelShow2_plain.objects.create(field1="2", field2="1") + obj3_2 = ModelShow2_plain.objects.create(field1="3", field2="2") + + with self.assertNumQueries(1): + obj = ModelShow2_plain.objects.filter(pk=obj2_1.pk)[0] + _ = (obj.field1, obj.field1) + + def test_select_related_on_poly_classes(self): + plain_a_obj_1 = PlainA.objects.create(field1="f1") + plain_a_obj_2 = PlainA.objects.create(field1="f2") + extra_obj = ModelExtraExternal.objects.create(topic="t1") + obj_p = ParentModel.objects.create(name="p1") + obj_c = ChildModel.objects.create(name="c1", other_name="c1name", link_on_child=extra_obj) + obj_ac1 = AltChildModel.objects.create( + name="ac1", other_name="ac1name", link_on_altchild=plain_a_obj_1 + ) + obj_ac2 = AltChildModel.objects.create( + name="ac2", other_name="ac2name", link_on_altchild=plain_a_obj_2 + ) + obj_p_1 = PlainModel.objects.create(relation=obj_p) + obj_p_2 = PlainModel.objects.create(relation=obj_c) + obj_p_3 = PlainModel.objects.create(relation=obj_ac1) + obj_p_4 = PlainModel.objects.create(relation=obj_ac2) + + ContentType.objects.get_for_models(PlainA, ModelExtraExternal, AltChildModel) + + with self.assertNumQueries(1): + # pos 3 if i cannot do optimized select_related + obj_list = list( + PlainModel.objects.select_related( + "relation", + "relation__ChildModel__link_on_child", + "relation__AltChildModel__link_on_altchild", + ).order_by("pk") + ) + with self.assertNumQueries(0): + assert obj_list[0].relation.name == "p1" + assert obj_list[1].relation.name == "c1" + assert obj_list[2].relation.name == "ac1" + assert obj_list[3].relation.name == "ac2" + obj_list[1].relation.link_on_child + obj_list[2].relation.link_on_altchild + obj_list[3].relation.link_on_altchild + + self.assertIsInstance(obj_list[0].relation, ParentModel) + self.assertIsInstance(obj_list[1].relation, ChildModel) + self.assertIsInstance(obj_list[2].relation, AltChildModel) + self.assertIsInstance(obj_list[3].relation, AltChildModel) + + def test_select_related_can_merge_fields(self): + # can we fetch the related object but only the minimal 'common' values + plain_a_obj_1 = PlainA.objects.create(field1="f1") + plain_a_obj_2 = PlainA.objects.create(field1="f2") + extra_obj = ModelExtraExternal.objects.create(topic="t1") + obj_p = ParentModel.objects.create(name="p1") + obj_c = ChildModel.objects.create(name="c1", other_name="c1name", link_on_child=extra_obj) + obj_ac1 = AltChildModel.objects.create( + name="ac1", other_name="ac1name", link_on_altchild=plain_a_obj_1 + ) + obj_ac2 = AltChildModel.objects.create( + name="ac2", other_name="ac2name", link_on_altchild=plain_a_obj_2 + ) + obj_p_1 = PlainModel.objects.create(relation=obj_p) + obj_p_2 = PlainModel.objects.create(relation=obj_c) + obj_p_3 = PlainModel.objects.create(relation=obj_ac1) + obj_p_4 = PlainModel.objects.create(relation=obj_ac2) + ContentType.objects.get_for_models(PlainA, ModelExtraExternal, AltChildModel) + base_query = PlainModel.objects.select_related( + "relation__ChildModel", + ) + base_query = base_query.select_related("relation__AltChildModel") + with self.assertNumQueries(1): + list(base_query) + + def test_select_related_on_poly_classes_simple(self): + # can we fetch the related object but only the minimal 'common' values + plain_a_obj_1 = PlainA.objects.create(field1="f1") + plain_a_obj_2 = PlainA.objects.create(field1="f2") + extra_obj = ModelExtraExternal.objects.create(topic="t1") + obj_p = ParentModel.objects.create(name="p1") + obj_c = ChildModel.objects.create(name="c1", other_name="c1name", link_on_child=extra_obj) + obj_ac1 = AltChildModel.objects.create( + name="ac1", other_name="ac1name", link_on_altchild=plain_a_obj_1 + ) + obj_ac2 = AltChildModel.objects.create( + name="ac2", other_name="ac2name", link_on_altchild=plain_a_obj_2 + ) + obj_p_1 = PlainModel.objects.create(relation=obj_p) + obj_p_2 = PlainModel.objects.create(relation=obj_c) + obj_p_3 = PlainModel.objects.create(relation=obj_ac1) + obj_p_4 = PlainModel.objects.create(relation=obj_ac2) + + with self.assertNumQueries(1): + # pos 3 if i cannot do optimized select_related + obj_list = list( + PlainModel.objects.select_related( + "relation", + "relation__ChildModel", + "relation__AltChildModel", + ) + .order_by("pk") + .only( + "relation__name", + "relation__polymorphic_ctype", + ) + ) + with self.assertNumQueries(0): + self.assertEqual(obj_list[0].relation.name, "p1") + self.assertEqual(obj_list[1].relation.name, "c1") + self.assertEqual(obj_list[2].relation.name, "ac1") + self.assertEqual(obj_list[3].relation.name, "ac2") + + self.assertIsInstance(obj_list[0].relation, ParentModel) + self.assertIsInstance(obj_list[1].relation, ChildModel) + self.assertIsInstance(obj_list[2].relation, AltChildModel) + self.assertIsInstance(obj_list[3].relation, AltChildModel) + + def test_we_can_upgrade_a_query_set_to_polymorphic_supports_already_ploy_qs(self): + base_qs = RefPlainModel.poly_objects.get_queryset() + self.assertIs(convert_to_polymorphic_queryset(base_qs), base_qs) + + def test_we_can_upgrade_a_query_set_to_polymorphic_supports_non_ploy_qs_on_ploy_object(self): + base_qs = RefPlainModel.objects.get_queryset() + self.assertIsNot(convert_to_polymorphic_queryset(base_qs), base_qs) + self.assertIsInstance( + convert_to_polymorphic_queryset(base_qs), PolymorphicRelatedQuerySetMixin + ) + + def test_we_can_upgrade_a_query_set_to_polymorphic_supports_non_ploy_managers_on_ploy_object( + self, + ): + base_qs = RefPlainModel.objects + self.assertIsNot(convert_to_polymorphic_queryset(base_qs), base_qs) + self.assertIsInstance( + convert_to_polymorphic_queryset(base_qs), PolymorphicRelatedQuerySetMixin + ) + + def test_we_can_upgrade_a_query_set_to_polymorphic(self): + # can we fetch the related object but only the minimal 'common' values + plain_a_obj_1 = PlainA.objects.create(field1="f1") + plain_a_obj_2 = PlainA.objects.create(field1="f2") + extra_obj = ModelExtraExternal.objects.create(topic="t1") + obj_p = ParentModel.objects.create(name="p1") + obj_c = ChildModel.objects.create(name="c1", other_name="c1name", link_on_child=extra_obj) + obj_ac1 = AltChildModel.objects.create( + name="ac1", other_name="ac1name", link_on_altchild=plain_a_obj_1 + ) + obj_ac2 = AltChildModel.objects.create( + name="ac2", other_name="ac2name", link_on_altchild=plain_a_obj_2 + ) + obj_p_1 = VanillaPlainModel.objects.create(relation=obj_p) + obj_p_2 = VanillaPlainModel.objects.create(relation=obj_c) + obj_p_3 = VanillaPlainModel.objects.create(relation=obj_ac1) + obj_p_4 = VanillaPlainModel.objects.create(relation=obj_ac2) + + with self.assertNumQueries(1): + # pos 3 if i cannot do optimized select_related + obj_list = list(VanillaPlainModel.objects.order_by("pk")) + + with self.assertNumQueries(7): + self.assertEqual(obj_list[0].relation.name, "p1") + self.assertEqual(obj_list[1].relation.name, "c1") + self.assertEqual(obj_list[2].relation.name, "ac1") + self.assertEqual(obj_list[3].relation.name, "ac2") + + with self.assertNumQueries(1): + # pos 3 if i cannot do optimized select_related + obj_list = list( + convert_to_polymorphic_queryset(VanillaPlainModel.objects) + .select_related( + "relation", + "relation__ChildModel", + "relation__AltChildModel", + ) + .order_by("pk") + ) + + with self.assertNumQueries(0): + self.assertEqual(obj_list[0].relation.name, "p1") + self.assertEqual(obj_list[1].relation.name, "c1") + self.assertEqual(obj_list[2].relation.name, "ac1") + self.assertEqual(obj_list[3].relation.name, "ac2") + + self.assertIsInstance(obj_list[0].relation, ParentModel) + self.assertIsInstance(obj_list[1].relation, ChildModel) + self.assertIsInstance(obj_list[2].relation, AltChildModel) + self.assertIsInstance(obj_list[3].relation, AltChildModel) + + def test_select_related_on_poly_classes_indirect_related(self): + # can we fetch the related object but only the minimal 'common' values + plain_a_obj_1 = PlainA.objects.create(field1="f1") + plain_a_obj_2 = PlainA.objects.create(field1="f2") + extra_obj = ModelExtraExternal.objects.create(topic="t1") + obj_p = ParentModel.objects.create(name="p1") + obj_c = ChildModel.objects.create(name="c1", other_name="c1name", link_on_child=extra_obj) + obj_ac1 = AltChildModel.objects.create( + name="ac1", other_name="ac1name", link_on_altchild=plain_a_obj_1 + ) + obj_ac2 = AltChildModel.objects.create( + name="ac2", other_name="ac2name", link_on_altchild=plain_a_obj_2 + ) + obj_p_1 = PlainModel.objects.create(relation=obj_p) + obj_p_2 = PlainModel.objects.create(relation=obj_c) + obj_p_3 = PlainModel.objects.create(relation=obj_ac1) + obj_p_4 = PlainModel.objects.create(relation=obj_ac2) + + robj_1 = RefPlainModel.objects.create(plainobj=obj_p_1) + robj_2 = RefPlainModel.objects.create(plainobj=obj_p_2) + robj_3 = RefPlainModel.objects.create(plainobj=obj_p_3) + robj_4 = RefPlainModel.objects.create(plainobj=obj_p_4) + + # Prefetch content_types + ContentType.objects.get_for_models(PlainModel, PlainA, ModelExtraExternal) + + with self.assertNumQueries(1): + # pos 3 if i cannot do optimized select_related + obj_list = list( + RefPlainModel.poly_objects.select_related( + # "plainobj__relation", + "plainobj__relation", + "plainobj__relation__ChildModel__link_on_child", + "plainobj__relation__AltChildModel__link_on_altchild", + ).order_by("pk") + ) + with self.assertNumQueries(0): + self.assertEqual(obj_list[0].plainobj.relation.name, "p1") + self.assertEqual(obj_list[1].plainobj.relation.name, "c1") + self.assertEqual(obj_list[2].plainobj.relation.name, "ac1") + self.assertEqual(obj_list[3].plainobj.relation.name, "ac2") + + self.assertIsInstance(obj_list[0].plainobj.relation, ParentModel) + self.assertIsInstance(obj_list[1].plainobj.relation, ChildModel) + self.assertIsInstance(obj_list[2].plainobj.relation, AltChildModel) + self.assertIsInstance(obj_list[3].plainobj.relation, AltChildModel) + + def test_select_related_fecth_all_poly_classes_indirect_related(self): + # can we fetch the related object but only the minimal 'common' values + plain_a_obj_1 = PlainA.objects.create(field1="f1") + plain_a_obj_2 = PlainA.objects.create(field1="f2") + extra_obj = ModelExtraExternal.objects.create(topic="t1") + obj_p = ParentModel.objects.create(name="p1") + obj_c = ChildModel.objects.create(name="c1", other_name="c1name", link_on_child=extra_obj) + obj_ac1 = AltChildModel.objects.create( + name="ac1", other_name="ac1name", link_on_altchild=plain_a_obj_1 + ) + obj_ac2 = AltChildModel.objects.create( + name="ac2", other_name="ac2name", link_on_altchild=plain_a_obj_2 + ) + obj_p_1 = PlainModel.objects.create(relation=obj_p) + obj_p_2 = PlainModel.objects.create(relation=obj_c) + obj_p_3 = PlainModel.objects.create(relation=obj_ac1) + obj_p_4 = PlainModel.objects.create(relation=obj_ac2) + + robj_1 = RefPlainModel.objects.create(plainobj=obj_p_1) + robj_2 = RefPlainModel.objects.create(plainobj=obj_p_2) + robj_3 = RefPlainModel.objects.create(plainobj=obj_p_3) + robj_4 = RefPlainModel.objects.create(plainobj=obj_p_4) + + # Prefetch content_types + ContentType.objects.get_for_models( + AltChildAsBaseModel, + AltChildWithM2MModel, + ModelExtraExternal, + NonAutoPKChild, + PlainA, + PlainModel, + ) + + with self.assertNumQueries(1): + # pos 3 if i cannot do optimized select_related + obj_list = list( + RefPlainModel.poly_objects.select_related( + # "plainobj__relation", + "plainobj__relation", + "plainobj__relation__*", + ).order_by("pk") + ) + with self.assertNumQueries(0): + self.assertEqual(obj_list[0].plainobj.relation.name, "p1") + self.assertEqual(obj_list[1].plainobj.relation.name, "c1") + self.assertEqual(obj_list[2].plainobj.relation.name, "ac1") + self.assertEqual(obj_list[3].plainobj.relation.name, "ac2") + + self.assertIsInstance(obj_list[0].plainobj.relation, ParentModel) + self.assertIsInstance(obj_list[1].plainobj.relation, ChildModel) + self.assertIsInstance(obj_list[2].plainobj.relation, AltChildModel) + self.assertIsInstance(obj_list[3].plainobj.relation, AltChildModel) + + def test_select_related_on_poly_classes_supports_multi_level_inheritance(self): + plain_a_obj_1 = PlainA.objects.create(field1="f1") + plain_a_obj_2 = PlainA.objects.create(field1="f2") + extra_obj = ModelExtraExternal.objects.create(topic="t1") + obj_p = ParentModel.objects.create(name="p1") + obj_c = ChildModel.objects.create(name="c1", other_name="c1name", link_on_child=extra_obj) + obj_ac1 = AltChildModel.objects.create( + name="ac1", other_name="ac1name", link_on_altchild=plain_a_obj_1 + ) + obj_acab2 = AltChildAsBaseModel.objects.create( + name="ac2ab", + other_name="acab2name", + more_name="acab2_mn", + link_on_altchild=plain_a_obj_2, + ) + + obj_p_1 = PlainModel.objects.create(relation=obj_p) + obj_p_2 = PlainModel.objects.create(relation=obj_c) + obj_p_3 = PlainModel.objects.create(relation=obj_ac1) + obj_p_4 = PlainModel.objects.create(relation=obj_acab2) + + ContentType.objects.get_for_models(PlainA, ModelExtraExternal) + + with self.assertNumQueries(1): + # pos 3 if i cannot do optimized select_related + obj_list = list( + PlainModel.objects.select_related( + "relation", + "relation__ChildModel__link_on_child", + "relation__AltChildModel__link_on_altchild", + "relation__AltChildAsBaseModel__link_on_altchild", + ).order_by("pk") + ) + with self.assertNumQueries(0): + assert obj_list[0].relation.name == "p1" + assert obj_list[1].relation.name == "c1" + assert obj_list[2].relation.name == "ac1" + assert obj_list[3].relation.name == "ac2ab" + assert obj_list[3].relation.more_name == "acab2_mn" + obj_list[1].relation.link_on_child + obj_list[2].relation.link_on_altchild + obj_list[3].relation.link_on_altchild + + def test_select_related_on_poly_classes_with_modelname(self): + plain_a_obj_1 = PlainA.objects.create(field1="f1") + extra_obj = ModelExtraExternal.objects.create(topic="t1") + obj_p = ParentModel.objects.create(name="p1") + obj_c = ChildModel.objects.create(name="c1", other_name="c1name", link_on_child=extra_obj) + obj_acab2 = AltChildAsBaseModel.objects.create( + name="acab2", + other_name="acab2name", + more_name="acab2_mn", + link_on_altchild=plain_a_obj_1, + ) + obj_p_1 = PlainModel.objects.create(relation=obj_p) + obj_p_2 = PlainModel.objects.create(relation=obj_c) + obj_p_3 = PlainModel.objects.create(relation=obj_acab2) + + ContentType.objects.get_for_models(PlainA, ModelExtraExternal, AltChildModel) + + with self.assertNumQueries(1): + obj_list = list( + PlainModel.objects.select_related( + "relation", + "relation__ChildModel__link_on_child", + "relation__AltChildAsBaseModel__link_on_altchild", + ).order_by("pk") + ) + + with self.assertNumQueries(0): + assert obj_list[0].relation.name == "p1" + assert obj_list[1].relation.name == "c1" + assert obj_list[2].relation.name == "acab2" + obj_list[1].relation.link_on_child + obj_list[2].relation.link_on_altchild + + def test_prefetch_related_from_basepoly(self): + obja1 = NonSymRelationA.objects.create(field_a="fa1", field_base="fa1") + obja2 = NonSymRelationA.objects.create(field_a="fa2", field_base="fa2") + objb1 = NonSymRelationB.objects.create(field_b="fb1", field_base="fb1") + objbc1 = NonSymRelationBC.objects.create(field_c="fbc1", field_base="fbc1") + + obja3 = NonSymRelationA.objects.create(field_a="fa3", field_base="fa3") + # NOTE: these are symmetric links + obja3.m2m.add(obja2) + obja3.m2m.add(objb1) + obja2.m2m.add(objbc1) + + # NOTE: prefetch content types so query asserts test data fetched. + ContentType.objects.get_for_model(NonSymRelationBase) + + with self.assertNumQueries(10): + # query for NonSymRelationBase (base) + # query for NonSymRelationA # level 1 (base) + # query for NonSymRelationB # level 1 (base) + # query for NonSymRelationBC # level 1 (base) + # query for prefetch links (m2m) + # query for NonSymRelationA # level 2 (m2m) + # query for NonSymRelationB # level 2 (m2m) + # query for NonSymRelationBC # level 2 (m2m) + # query for prefetch links (m2m__m2m) + # query for NonSymRelationA # level 3 (m2m__m2m) + # query for NonSymRelationB # level 3 (m2m__m2m) [SKIPPED AS NO DATA] + # query for NonSymRelationC # level 3 (m2m__m2m) [SKIPPED AS NO DATA] + + all_objs = { + obj.pk: obj + for obj in NonSymRelationBase.objects.prefetch_related("m2m", "m2m__m2m") + } + + with self.assertNumQueries(0): + relations = {obj.pk: set(obj.m2m.all()) for obj in all_objs.values()} + + with self.assertNumQueries(0): + sub_relations = {a.pk: set(a.m2m.all()) for a in all_objs.get(obja3.pk).m2m.all()} + + self.assertDictEqual( + { + obja1.pk: set(), + obja2.pk: set([objbc1]), + obja3.pk: set([obja2, objb1]), + objb1.pk: set([]), + objbc1.pk: set([]), + }, + relations, + ) + + self.assertDictEqual( + { + obja2.pk: set([objbc1]), + objb1.pk: set([]), + }, + sub_relations, + ) + + def test_prefetch_related_from_subclass(self): + obja1 = NonSymRelationA.objects.create(field_a="fa1", field_base="fa1") + obja2 = NonSymRelationA.objects.create(field_a="fa2", field_base="fa2") + objb1 = NonSymRelationB.objects.create(field_b="fb1", field_base="fb1") + objbc1 = NonSymRelationBC.objects.create(field_c="fbc1", field_base="fbc1") + + obja3 = NonSymRelationA.objects.create(field_a="fa3", field_base="fa3") + # NOTE: these are symmetric links + obja3.m2m.add(obja2) + obja3.m2m.add(objb1) + obja2.m2m.add(objbc1) + + # NOTE: prefetch content types so query asserts test data fetched. + ContentType.objects.get_for_model(NonSymRelationBase) + + with self.assertNumQueries(7): + # query for NonSymRelationA # level 1 (base) + # query for prefetch links (m2m) + # query for NonSymRelationA # level 2 (m2m) + # query for NonSymRelationB # level 2 (m2m) + # query for NonSymRelationBC # level 2 (m2m) + # query for prefetch links (m2m__m2m) + # query for NonSymRelationA # level 3 (m2m__m2m) + # query for NonSymRelationB # level 3 (m2m__m2m) [SKIPPED AS NO DATA] + # query for NonSymRelationC # level 3 (m2m__m2m) [SKIPPED AS NO DATA] + + all_objs = { + obj.pk: obj for obj in NonSymRelationA.objects.prefetch_related("m2m", "m2m__m2m") + } + + with self.assertNumQueries(0): + relations = {obj.pk: set(obj.m2m.all()) for obj in all_objs.values()} + + with self.assertNumQueries(0): + sub_relations = {a.pk: set(a.m2m.all()) for a in all_objs.get(obja3.pk).m2m.all()} + + self.assertDictEqual( + { + obja1.pk: set(), + obja2.pk: set([objbc1]), + obja3.pk: set([obja2, objb1]), + }, + relations, + ) + + self.assertDictEqual( + { + obja2.pk: set([objbc1]), + objb1.pk: set([]), + }, + sub_relations, + ) + + def test_select_related_field_from_polymorphic_child_class(self): + # 198 + obj_p1 = ParentModel.objects.create(name="p1") + obj_p2 = ParentModel.objects.create(name="p2") + obj_p3 = ParentModel.objects.create(name="p4") + obj_c1 = ChildModel.objects.create(name="c1", other_name="c1name") + obj_c2 = ChildModel.objects.create(name="c2", other_name="c2name") + obj_ac1 = AltChildModel.objects.create(name="ac1", other_name="ac1name") + obj_ac2 = AltChildModel.objects.create(name="ac2", other_name="ac2name") + obj_ac3 = AltChildModel.objects.create(name="ac3", other_name="ac3name") + + with self.assertNumQueries(2): + # Queries will be + # * 1 for All ParentModel object (x4 +bases of all) + # * 1 for ChildModel object (x2) + # * 0 for AltChildModel object as from select_related (x3) + all_objs = [ + obj + for obj in ParentModel.objects.select_related( + "AltChildModel", + ) + ] + + def test_select_related_field_from_polymorphic_child_class_using_modelnames_level1(self): + # 198 + obj_p1 = ParentModel.objects.create(name="p1") + obj_p2 = ParentModel.objects.create(name="p2") + obj_p3 = ParentModel.objects.create(name="p4") + obj_c1 = ChildModel.objects.create(name="c1", other_name="c1name") + obj_c2 = ChildModel.objects.create(name="c2", other_name="c2name") + obj_ac1 = AltChildModel.objects.create(name="ac1", other_name="ac1name") + obj_ac2 = AltChildModel.objects.create(name="ac2", other_name="ac2name") + obj_ac3 = AltChildModel.objects.create(name="ac3", other_name="ac3name") + + with self.assertNumQueries(2): + # Queries will be + # * 1 for All ParentModel object (x4 +bases of all) + # * 1 for ChildModel object (x2) + # * 0 for AltChildModel object as from select_related (x3) + all_objs = [ + obj + for obj in ParentModel.objects.select_related( + "AltChildModel", + ) + ] + + def test_select_related_field_from_polymorphic_child_class_using_modelnames_multi_level(self): + plain_a_obj_1 = PlainA.objects.create(field1="f1") + + obj_p1 = ParentModel.objects.create(name="p1") + obj_acab2 = AltChildAsBaseModel.objects.create( + name="acab2", + other_name="acab2name", + more_name="acab2_mn", + link_on_altchild=plain_a_obj_1, + ) + obj_c1 = ChildModel.objects.create(name="c1", other_name="c1name") + obj_ac3 = ChildModel.objects.create(name="c2", other_name="c3name") + + # NOTE: prefetch content types so query asserts test data fetched. + ContentType.objects.get_for_model(AltChildModel) + + with self.assertNumQueries(2): + # Queries will be + # * 1 for All ParentModel object (x4 +bases of all) + # * 1 for ChildModel object (x1) + # * 0 for AltChildAsBaseModel object as from select_related (x1) + # * 0 for AltChildModel object as part of select_related form + # AltChildAsBaseModel (x1) + all_objs = [obj for obj in ParentModel.objects.select_related("AltChildAsBaseModel")] + + def test_prefetch_object_is_supported(self): + b1 = RelatingModel.objects.create() + b2 = RelatingModel.objects.create() + + rel1 = Model2A.objects.create(field1="A1") + rel2 = Model2B.objects.create(field1="A2", field2="B2") + + b1.many2many.add(rel1) + b2.many2many.add(rel2) + + rel2.delete(keep_parents=True) + + qs = RelatingModel.objects.order_by("pk").prefetch_related( + Prefetch("many2many", queryset=Model2A.objects.all(), to_attr="poly"), + Prefetch("many2many", queryset=Model2A.objects.non_polymorphic(), to_attr="non_poly"), + ) + + objects = list(qs) + assert len(objects[0].poly) == 1 + + # derived object was not fetched + assert len(objects[1].poly) == 0 + + # base object always found + assert len(objects[0].non_poly) == 1 + assert len(objects[1].non_poly) == 1 + + def test_prefetch_related_on_poly_classes_preserves_on_relations_annotations(self): + b1 = RelatingModel.objects.create() + b2 = RelatingModel.objects.create() + b3 = RelatingModel.objects.create() + + rel1 = Model2A.objects.create(field1="A1") + rel2 = Model2B.objects.create(field1="A2", field2="B2") + + b1.many2many.add(rel1) + b2.many2many.add(rel2) + b3.many2many.add(rel2) + + qs = RelatingModel.objects.order_by("pk").prefetch_related( + Prefetch( + "many2many", + queryset=Model2A.objects.annotate(Count("relatingmodel")), + to_attr="poly", + ) + ) + + objects = list(qs) + assert objects[0].poly[0].relatingmodel__count == 1 + assert objects[1].poly[0].relatingmodel__count == 2 + assert objects[2].poly[0].relatingmodel__count == 2 + + @expectedFailure + def test_prefetch_loading_relation_only_on_some_poly_model(self): + plain_a_obj_1 = PlainA.objects.create(field1="p1") + plain_a_obj_2 = PlainA.objects.create(field1="p2") + plain_a_obj_3 = PlainA.objects.create(field1="p3") + plain_a_obj_4 = PlainA.objects.create(field1="p4") + plain_a_obj_5 = PlainA.objects.create(field1="p5") + + ac_m2m_obj = AltChildWithM2MModel.objects.create( + other_name="o1", + ) + ac_m2m_obj.m2m.set([plain_a_obj_1, plain_a_obj_2, plain_a_obj_3]) + + cm_1 = ChildModel.objects.create(other_name="c1") + cm_2 = ChildModel.objects.create(other_name="c2") + cm_3 = ChildModel.objects.create(other_name="c3") + + acm_1 = AltChildModel.objects.create(other_name="ac3", link_on_altchild=plain_a_obj_4) + acm_2 = AltChildModel.objects.create(other_name="ac3", link_on_altchild=plain_a_obj_5) + + pm_1 = PlainModelWithM2M.objects.create(field1="pm1") + pm_2 = PlainModelWithM2M.objects.create(field1="pm2") + + pm_1.m2m.set([cm_1, cm_2]) + pm_2.m2m.set( + [ + cm_3, + ] + ) + + # NOTE: prefetch content types so query asserts test data fetched. + ContentType.objects.get_for_model(ParentModel) + + pm_2.m2m.set([ac_m2m_obj]) + with self.assertNumQueries(4): + # query for PlainModelWithM2M # level 1 (base) + # query for prefetch links (m2m) + # query for ChildModel # level 2 (m2m) + # query for AltChildWithM2MModel # level 2 (m2m) + qs = PlainModelWithM2M.objects.all() + qs = qs.prefetch_related("m2m__altchildmodel__altchildWithm2mmodel__m2m") + all_objs = list(qs) + + @expectedFailure + def test_prefetch_loading_relation_only_on_some_poly_model_using_modelnames(self): + plain_a_obj_1 = PlainA.objects.create(field1="p1") + plain_a_obj_2 = PlainA.objects.create(field1="p2") + plain_a_obj_3 = PlainA.objects.create(field1="p3") + plain_a_obj_4 = PlainA.objects.create(field1="p4") + plain_a_obj_5 = PlainA.objects.create(field1="p5") + + ac_m2m_obj = AltChildWithM2MModel.objects.create( + other_name="o1", + ) + ac_m2m_obj.m2m.set([plain_a_obj_1, plain_a_obj_2, plain_a_obj_3]) + + cm_1 = ChildModel.objects.create(other_name="c1") + cm_2 = ChildModel.objects.create(other_name="c2") + cm_3 = ChildModel.objects.create(other_name="c3") + + acm_1 = AltChildModel.objects.create(other_name="ac3", link_on_altchild=plain_a_obj_4) + acm_2 = AltChildModel.objects.create(other_name="ac3", link_on_altchild=plain_a_obj_5) + + pm_1 = PlainModelWithM2M.objects.create(field1="pm1") + pm_2 = PlainModelWithM2M.objects.create(field1="pm2") + + pm_1.m2m.set([cm_1, cm_2]) + pm_2.m2m.set( + [ + cm_3, + ] + ) + + # NOTE: prefetch content types so query asserts test data fetched. + ContentType.objects.get_for_model(ParentModel) + + pm_2.m2m.set([ac_m2m_obj]) + with self.assertNumQueries(4): + # query for PlainModelWithM2M # level 1 (base) + # query for prefetch links (m2m) + # query for ChildModel # level 2 (m2m) + # query for AltChildWithM2MModel # level 2 (m2m) + qs = PlainModelWithM2M.objects.all() + qs = qs.prefetch_related("m2m__AltChildWithM2MModel__m2m") + all_objs = list(qs)