diff --git a/netbox_custom_objects/__init__.py b/netbox_custom_objects/__init__.py index 476c024..da93f43 100644 --- a/netbox_custom_objects/__init__.py +++ b/netbox_custom_objects/__init__.py @@ -1,7 +1,6 @@ import sys import warnings -from django.core.exceptions import AppRegistryNotReady from django.db import transaction from django.db.utils import DatabaseError, OperationalError, ProgrammingError from netbox.plugins import PluginConfig @@ -52,16 +51,40 @@ class CustomObjectsPluginConfig(PluginConfig): required_settings = [] template_extensions = "template_content.template_extensions" + def ready(self): + from .models import CustomObjectType + from netbox_custom_objects.api.serializers import get_serializer_class + + # Suppress warnings about database calls during app initialization + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", category=RuntimeWarning, message=".*database.*" + ) + warnings.filterwarnings( + "ignore", category=UserWarning, message=".*database.*" + ) + + # Skip database calls if running during migration or if table doesn't exist + if is_running_migration() or not check_custom_object_type_table_exists(): + super().ready() + return + + qs = CustomObjectType.objects.all() + for obj in qs: + model = obj.get_model() + get_serializer_class(model) + + super().ready() + def get_model(self, model_name, require_ready=True): + self.apps.check_apps_ready() try: # if the model is already loaded, return it return super().get_model(model_name, require_ready) except LookupError: - try: - self.apps.check_apps_ready() - except AppRegistryNotReady: - raise + pass + model_name = model_name.lower() # only do database calls if we are sure the app is ready to avoid # Django warnings if "table" not in model_name.lower() or "model" not in model_name.lower(): diff --git a/netbox_custom_objects/field_types.py b/netbox_custom_objects/field_types.py index 981ce36..f10bce3 100644 --- a/netbox_custom_objects/field_types.py +++ b/netbox_custom_objects/field_types.py @@ -396,7 +396,6 @@ def get_model_field(self, field, **kwargs): to_model = content_type.model # Extract our custom parameters and keep only Django field parameters - generating_models = kwargs.pop('_generating_models', getattr(self, '_generating_models', set())) field_kwargs = {k: v for k, v in kwargs.items() if not k.startswith('_')} field_kwargs.update({"default": field.default, "unique": field.unique}) @@ -427,27 +426,7 @@ def get_model_field(self, field, **kwargs): return f else: # For cross-referential fields, use skip_object_fields to avoid infinite loops - # Check if we're in a recursion situation using the parameter or stored attribute - if generating_models and custom_object_type.id in generating_models: - # We're in a circular reference, don't call get_model() to prevent recursion - # Use a string reference instead - model_name = f"{APP_LABEL}.{custom_object_type.get_table_model_name(custom_object_type.id)}" - # Generate a unique related_name to prevent reverse accessor conflicts - table_model_name = field.custom_object_type.get_table_model_name( - field.custom_object_type.id - ).lower() - related_name = f"{table_model_name}_{field.name}_set" - f = models.ForeignKey( - model_name, - null=True, - blank=True, - on_delete=models.CASCADE, - related_name=related_name, - **field_kwargs - ) - return f - else: - model = custom_object_type.get_model(skip_object_fields=True) + model = custom_object_type.get_model(skip_object_fields=True) else: # to_model = content_type.model_class()._meta.object_name to_ct = f"{content_type.app_label}.{to_model}" @@ -479,19 +458,7 @@ def get_form_field(self, field, for_csv_import=False, **kwargs): ) custom_object_type = CustomObjectType.objects.get(pk=custom_object_type_id) - # Check if we're in a recursion situation - generating_models = getattr(self, '_generating_models', set()) - if generating_models and custom_object_type.id in generating_models: - # We're in a circular reference, don't call get_model() to prevent recursion - # Use a minimal approach or return a basic field - return DynamicModelChoiceField( - queryset=custom_object_type.get_model(skip_object_fields=True).objects.all(), - required=field.required, - # Remove initial=field.default to allow Django to handle instance data properly - selector=True, - ) - else: - model = custom_object_type.get_model() + model = custom_object_type.get_model() else: # This is a regular NetBox model model = content_type.model_class() @@ -801,20 +768,7 @@ def get_form_field(self, field, for_csv_import=False, **kwargs): ) custom_object_type = CustomObjectType.objects.get(pk=custom_object_type_id) - # For cross-referential fields, use skip_object_fields to avoid infinite loops - # Check if we're in a recursion situation using the parameter or stored attribute - generating_models = getattr(self, '_generating_models', set()) - if generating_models and custom_object_type.id in generating_models: - # We're in a circular reference, don't call get_model() to prevent recursion - # Use a minimal approach or return a basic field - return DynamicModelMultipleChoiceField( - queryset=custom_object_type.get_model(skip_object_fields=True).objects.all(), - required=field.required, - # Remove initial=field.default to allow Django to handle instance data properly - selector=True, - ) - else: - model = custom_object_type.get_model(skip_object_fields=True) + model = custom_object_type.get_model(skip_object_fields=True) else: # This is a regular NetBox model model = content_type.model_class() @@ -905,13 +859,7 @@ def after_model_generation(self, instance, model, field_name): # Self-referential field - resolve to current model to_model = model else: - # Cross-referential field - check for recursion before calling get_model() - generating_models = getattr(self, '_generating_models', set()) - if generating_models and custom_object_type.id in generating_models: - # We're in a circular reference, don't call get_model() to prevent recursion - return - else: - to_model = custom_object_type.get_model() + to_model = custom_object_type.get_model() else: to_ct = f"{content_type.app_label}.{content_type.model}" to_model = apps.get_model(to_ct) @@ -956,14 +904,7 @@ def create_m2m_table(self, instance, model, field_name): pk=custom_object_type_id ) - # Check if we're in a recursion situation - generating_models = getattr(self, '_generating_models', set()) - if generating_models and custom_object_type.id in generating_models: - # We're in a circular reference, don't call get_model() to prevent recursion - # Use a minimal approach or skip this field - return - else: - to_model = custom_object_type.get_model() + to_model = custom_object_type.get_model() else: to_model = content_type.model_class() diff --git a/netbox_custom_objects/models.py b/netbox_custom_objects/models.py index 2b33d6d..8ad9af5 100644 --- a/netbox_custom_objects/models.py +++ b/netbox_custom_objects/models.py @@ -1,5 +1,6 @@ import decimal import re +import threading from datetime import date, datetime import django_filters @@ -165,6 +166,8 @@ class CustomObjectType(PrimaryModel): _through_model_cache = ( {} ) # Now stores {custom_object_type_id: {through_model_name: through_model}} + _model_cache_locks = {} # Per-model locks to prevent race conditions + _global_lock = threading.RLock() # Global lock for managing per-model locks name = models.CharField( max_length=100, unique=True, @@ -219,25 +222,19 @@ def clear_model_cache(cls, custom_object_type_id=None): :param custom_object_type_id: ID of the CustomObjectType to clear cache for, or None to clear all """ - if custom_object_type_id is not None: - cls._model_cache.pop(custom_object_type_id, None) - cls._through_model_cache.pop(custom_object_type_id, None) - else: - cls._model_cache.clear() - cls._through_model_cache.clear() + with cls._global_lock: + if custom_object_type_id is not None: + cls._model_cache.pop(custom_object_type_id, None) + cls._through_model_cache.pop(custom_object_type_id, None) + cls._model_cache_locks.pop(custom_object_type_id, None) + else: + cls._model_cache.clear() + cls._through_model_cache.clear() + cls._model_cache_locks.clear() # Clear Django apps registry cache to ensure newly created models are recognized apps.get_models.cache_clear() - # Clear global recursion tracking when clearing cache - cls.clear_global_recursion_tracking() - - @classmethod - def clear_global_recursion_tracking(cls): - """Clear the global recursion tracking set.""" - if hasattr(cls, '_global_generating_models'): - cls._global_generating_models.clear() - @classmethod def get_cached_model(cls, custom_object_type_id): """ @@ -314,7 +311,6 @@ def _fetch_and_generate_field_attrs( self, fields, skip_object_fields=False, - generating_models=None, ): field_attrs = { "_primary_field_id": -1, @@ -338,51 +334,8 @@ def _fetch_and_generate_field_attrs( field_type = FIELD_TYPE_CLASS[field.type]() field_name = field.name - # Pass generating models set to field generation to prevent infinite loops - field_type._generating_models = generating_models - - # Check if we're in a recursion situation before generating the field - # Use depth-based recursion control: allow self-referential fields at level 0, skip at deeper levels - should_skip = False - - # Calculate depth correctly: depth 0 is when we're generating the main model - # depth 1+ is when we're generating related models recursively - current_depth = len(generating_models) - 1 if generating_models else 0 - - if field.type in [CustomFieldTypeChoices.TYPE_OBJECT, CustomFieldTypeChoices.TYPE_MULTIOBJECT]: - if field.related_object_type: - # Check if this field references the same CustomObjectType (self-referential) - if field.related_object_type.app_label == APP_LABEL: - # This is a custom object type - from django.contrib.contenttypes.models import ContentType - content_type = ContentType.objects.get(pk=field.related_object_type_id) - if content_type.app_label == APP_LABEL: - # Extract the custom object type ID from the model name - # The model name format is "table{id}model" or similar - model_name = content_type.model - - # Try to extract the ID from the model name - id_match = re.search(r'table(\d+)model', model_name, re.IGNORECASE) - if id_match: - custom_object_type_id = int(id_match.group(1)) - - if custom_object_type_id == self.id: - # This is a self-referential field - if current_depth == 0: - # At level 0, allow self-referential fields - should_skip = False - else: - # At deeper levels, skip self-referential fields to prevent infinite recursion - should_skip = True - - if should_skip: - # Skip this field to prevent further recursion - field_attrs["_skipped_fields"].add(field.name) - continue - field_attrs[field.name] = field_type.get_model_field( field, - _generating_models=generating_models, # Pass as prefixed parameter ) # Add to field objects only if the field was successfully generated @@ -492,64 +445,26 @@ def register_custom_object_search_index(self, model): def get_model( self, - fields=None, - manytomany_models=None, - app_label=None, skip_object_fields=False, - no_cache=False, - _generating_models=None, ): """ Generates a temporary Django model based on available fields that belong to this table. Returns cached model if available, otherwise generates and caches it. - :param fields: Extra table field instances that need to be added the model. - :type fields: list - :param manytomany_models: In some cases with related fields a model has to be - generated in order to generate that model. In order to prevent a - recursion loop we cache the generated models and pass those along. - :type manytomany_models: dict - :param app_label: In some cases with related fields, the related models must - have the same app_label. If passed along in this parameter, then the - generated model will use that one instead of generating a unique one. - :type app_label: Optional[String] :param skip_object_fields: Don't add object or multiobject fields to the model :type skip_object_fields: bool - :param no_cache: Don't cache the generated model or attempt to pull from cache - :type no_cache: bool - :param _generating_models: Internal parameter to track models being generated - :type _generating_models: set :return: The generated model. :rtype: Model """ - # Check if we have a cached model for this CustomObjectType - if self.is_model_cached(self.id) and not no_cache: + # Double-check pattern: check cache again after acquiring lock + if self.is_model_cached(self.id): model = self.get_cached_model(self.id) - # Ensure the serializer is registered even for cached models - from netbox_custom_objects.api.serializers import get_serializer_class - - get_serializer_class(model) return model - # Circular reference detection using class-level tracking - if not hasattr(CustomObjectType, '_global_generating_models'): - CustomObjectType._global_generating_models = set() - - if _generating_models is None: - _generating_models = CustomObjectType._global_generating_models - - # Add this model to the set of models being generated - _generating_models.add(self.id) - - if app_label is None: - app_label = APP_LABEL - + # Generate the model inside the lock to prevent race conditions model_name = self.get_table_model_name(self.pk) - if fields is None: - fields = [] - # TODO: Add other fields with "index" specified indexes = [] @@ -576,10 +491,10 @@ def get_model( } # Pass the generating models set to field generation + fields = [] field_attrs = self._fetch_and_generate_field_attrs( fields, skip_object_fields=skip_object_fields, - generating_models=_generating_models ) attrs.update(**field_attrs) @@ -612,44 +527,32 @@ def wrapped_post_through_setup(self, cls): TM.post_through_setup = original_post_through_setup # Register the main model with Django's app registry - try: - existing_model = apps.get_model(APP_LABEL, model_name) - # If model exists but is different, we have a problem - if existing_model is not model: - # Use the existing model to avoid conflicts - model = existing_model - except LookupError: - apps.register_model(APP_LABEL, model) + if model_name.lower() in apps.all_models[APP_LABEL]: + # Remove the existing model from all_models before registering the new one + del apps.all_models[APP_LABEL][model_name.lower()] - if not manytomany_models: - self._after_model_generation(attrs, model) + apps.register_model(APP_LABEL, model) - # Cache the generated model - if not no_cache: - self._model_cache[self.id] = model - # Do the clear cache now that we have it in the cache so there - # is no recursion. - apps.clear_cache() + self._after_model_generation(attrs, model) - # Register the serializer for this model - if not manytomany_models: - from netbox_custom_objects.api.serializers import get_serializer_class + # Cache the generated model + self._model_cache[self.id] = model - get_serializer_class(model) + # Do the clear cache now that we have it in the cache so there + # is no recursion. + apps.clear_cache() + ContentType.objects.clear_cache() # Register the global SearchIndex for this model self.register_custom_object_search_index(model) - # Clean up: remove this model from the set of models being generated - if _generating_models is not None: - _generating_models.discard(self.id) - # Also clean up from global tracking if this is the global set - if _generating_models is CustomObjectType._global_generating_models: - CustomObjectType._global_generating_models.discard(self.id) - # Clear global tracking when we're done to ensure clean state - if len(CustomObjectType._global_generating_models) == 0: - CustomObjectType._global_generating_models.clear() + return model + def get_model_with_serializer(self): + from netbox_custom_objects.api.serializers import get_serializer_class + model = self.get_model() + get_serializer_class(model) + self.register_custom_object_search_index(model) return model def create_model(self): diff --git a/netbox_custom_objects/views.py b/netbox_custom_objects/views.py index b89026b..c8dfad3 100644 --- a/netbox_custom_objects/views.py +++ b/netbox_custom_objects/views.py @@ -160,12 +160,12 @@ class CustomObjectTypeView(CustomObjectTableMixin, generic.ObjectView): def get_table(self, data, request, bulk_actions=True): self.custom_object_type = self.get_object(**self.kwargs) - model = self.custom_object_type.get_model() + model = self.custom_object_type.get_model_with_serializer() data = model.objects.all() return super().get_table(data, request, bulk_actions=False) def get_extra_context(self, request, instance): - model = instance.get_model() + model = instance.get_model_with_serializer() # Get fields and group them by group_name fields = instance.fields.all().order_by("group_name", "weight", "name") @@ -199,7 +199,7 @@ class CustomObjectTypeDeleteView(generic.ObjectDeleteView): def _get_dependent_objects(self, obj): dependent_objects = super()._get_dependent_objects(obj) - model = obj.get_model() + model = obj.get_model_with_serializer() dependent_objects[model] = list(model.objects.all()) # Find CustomObjectTypeFields that reference this CustomObjectType @@ -243,7 +243,7 @@ def get(self, request, *args, **kwargs): obj = self.get_object(**kwargs) form = ConfirmationForm(initial=request.GET) - model = obj.custom_object_type.get_model() + model = obj.custom_object_type.get_model_with_serializer() kwargs = { f"{obj.name}__isnull": False, } @@ -280,7 +280,7 @@ def get(self, request, *args, **kwargs): def _get_dependent_objects(self, obj): dependent_objects = super()._get_dependent_objects(obj) - model = obj.custom_object_type.get_model() + model = obj.custom_object_type.get_model_with_serializer() kwargs = { f"{obj.name}__isnull": False, } @@ -332,7 +332,7 @@ def get_queryset(self, request): self.custom_object_type = get_object_or_404( CustomObjectType, slug=custom_object_type ) - model = self.custom_object_type.get_model() + model = self.custom_object_type.get_model_with_serializer() return model.objects.all() def get_filterset(self): @@ -378,7 +378,7 @@ def get_queryset(self, request): object_type = get_object_or_404( CustomObjectType, slug=custom_object_type ) - model = object_type.get_model() + model = object_type.get_model_with_serializer() return model.objects.all() def get_object(self, **kwargs): @@ -386,7 +386,7 @@ def get_object(self, **kwargs): object_type = get_object_or_404( CustomObjectType, slug=custom_object_type ) - model = object_type.get_model() + model = object_type.get_model_with_serializer() # Filter out custom_object_type from kwargs for the object lookup lookup_kwargs = { k: v for k, v in self.kwargs.items() if k != "custom_object_type" @@ -436,7 +436,8 @@ def get_object(self, **kwargs): object_type = get_object_or_404( CustomObjectType, slug=custom_object_type ) - model = object_type.get_model() + model = object_type.get_model_with_serializer() + if not self.kwargs.get("pk", None): # We're creating a new object return model() @@ -593,7 +594,7 @@ def get_object(self, **kwargs): object_type = get_object_or_404( CustomObjectType, slug=custom_object_type ) - model = object_type.get_model() + model = object_type.get_model_with_serializer() return get_object_or_404(model.objects.all(), **self.kwargs) def get_return_url(self, request, obj=None): @@ -633,7 +634,7 @@ def get_queryset(self, request): self.custom_object_type = CustomObjectType.objects.get( slug=custom_object_type ) - model = self.custom_object_type.get_model() + model = self.custom_object_type.get_model_with_serializer() return model.objects.all() def get_form(self, queryset): @@ -691,7 +692,7 @@ def get_queryset(self, request): self.custom_object_type = CustomObjectType.objects.get( slug=custom_object_type ) - model = self.custom_object_type.get_model() + model = self.custom_object_type.get_model_with_serializer() return model.objects.all() @@ -720,7 +721,7 @@ def get_queryset(self, request): self.custom_object_type = CustomObjectType.objects.get( name__iexact=custom_object_type ) - model = self.custom_object_type.get_model() + model = self.custom_object_type.get_model_with_serializer() return model.objects.all() def get_model_form(self, queryset): @@ -772,7 +773,7 @@ def get(self, request, custom_object_type, **kwargs): object_type = get_object_or_404( CustomObjectType, slug=custom_object_type ) - model = object_type.get_model() + model = object_type.get_model_with_serializer() # Get the specific object lookup_kwargs = {k: v for k, v in kwargs.items() if k != "custom_object_type"} @@ -844,7 +845,7 @@ def get(self, request, custom_object_type, **kwargs): object_type = get_object_or_404( CustomObjectType, slug=custom_object_type ) - model = object_type.get_model() + model = object_type.get_model_with_serializer() # Get the specific object lookup_kwargs = {k: v for k, v in kwargs.items() if k != "custom_object_type"}