77
88from . import PolymorphicEmbeddedModelField
99from .array import ArrayField , ArrayLenTransform
10- from .embedded_model_array import KeyTransform as ArrayFieldKeyTransform
11- from .embedded_model_array import KeyTransformFactory as ArrayFieldKeyTransformFactory
10+ from .embedded_model_array import (
11+ EmbeddedModelArrayFieldTransform ,
12+ EmbeddedModelArrayFieldTransformFactory ,
13+ )
1214
1315
1416class PolymorphicEmbeddedModelArrayField (ArrayField ):
@@ -62,7 +64,15 @@ def get_transform(self, name):
6264 transform = super ().get_transform (name )
6365 if transform :
6466 return transform
65- return KeyTransformFactory (name , self )
67+ for model in self .base_field .embedded_models :
68+ with contextlib .suppress (FieldDoesNotExist ):
69+ field = model ._meta .get_field (name )
70+ break
71+ else :
72+ raise FieldDoesNotExist (
73+ f"The models of field '{ self .name } ' have no field named '{ name } '."
74+ )
75+ return PolymorphicArrayFieldTransformFactory (field )
6676
6777 def _get_lookup (self , lookup_name ):
6878 lookup = super ()._get_lookup (lookup_name )
@@ -79,32 +89,23 @@ def as_mql(self, compiler, connection):
7989 return EmbeddedModelArrayFieldLookups
8090
8191
82- class KeyTransform ( ArrayFieldKeyTransform ):
92+ class PolymorphicArrayFieldTransform ( EmbeddedModelArrayFieldTransform ):
8393 field_class_name = "PolymorphicEmbeddedModelArrayField"
8494
85- def __init__ (self , key_name , array_field , * args , ** kwargs ):
86- # Skip ArrayFieldKeyTransform .__init__()
95+ def __init__ (self , field , * args , ** kwargs ):
96+ # Skip EmbeddedModelArrayFieldTransform .__init__()
8797 Transform .__init__ (self , * args , ** kwargs )
88- self .array_field = array_field
89- self .key_name = key_name
90- for model in array_field .base_field .embedded_models :
91- with contextlib .suppress (FieldDoesNotExist ):
92- field = model ._meta .get_field (key_name )
93- break
94- else :
95- raise FieldDoesNotExist (
96- f"The models of field '{ array_field .name } ' have no field named '{ key_name } '."
97- )
9898 # Lookups iterate over the array of embedded models. A virtual column
9999 # of the queried field's type represents each element.
100100 column_target = field .clone ()
101- column_name = f"$item.{ key_name } "
101+ column_name = f"$item.{ field .column } "
102+ column_target .name = f"{ field .name } "
102103 column_target .db_column = column_name
103104 column_target .set_attributes_from_name (column_name )
104105 self ._lhs = Col (None , column_target )
105106 self ._sub_transform = None
106107
107108
108- class KeyTransformFactory ( ArrayFieldKeyTransformFactory ):
109+ class PolymorphicArrayFieldTransformFactory ( EmbeddedModelArrayFieldTransformFactory ):
109110 def __call__ (self , * args , ** kwargs ):
110- return KeyTransform (self .key_name , self . base_field , * args , ** kwargs )
111+ return PolymorphicArrayFieldTransform (self .field , * args , ** kwargs )
0 commit comments