| 
 | 1 | +import itertools  | 
 | 2 | + | 
 | 3 | +from django.db import (  | 
 | 4 | +    connection,  | 
 | 5 | +)  | 
 | 6 | +from django.test import (  | 
 | 7 | +    TransactionTestCase,  | 
 | 8 | +)  | 
 | 9 | + | 
 | 10 | +from .models import Address, Author, Book, new_apps  | 
 | 11 | + | 
 | 12 | + | 
 | 13 | +class SchemaTests(TransactionTestCase):  | 
 | 14 | +    available_apps = []  | 
 | 15 | +    models = [Address, Author, Book]  | 
 | 16 | + | 
 | 17 | +    # Utility functions  | 
 | 18 | + | 
 | 19 | +    def setUp(self):  | 
 | 20 | +        # local_models should contain test dependent model classes that will be  | 
 | 21 | +        # automatically removed from the app cache on test tear down.  | 
 | 22 | +        self.local_models = []  | 
 | 23 | +        # isolated_local_models contains models that are in test methods  | 
 | 24 | +        # decorated with @isolate_apps.  | 
 | 25 | +        self.isolated_local_models = []  | 
 | 26 | + | 
 | 27 | +    def tearDown(self):  | 
 | 28 | +        # Delete any tables made for our models  | 
 | 29 | +        self.delete_tables()  | 
 | 30 | +        new_apps.clear_cache()  | 
 | 31 | +        for model in new_apps.get_models():  | 
 | 32 | +            model._meta._expire_cache()  | 
 | 33 | +        if "schema" in new_apps.all_models:  | 
 | 34 | +            for model in self.local_models:  | 
 | 35 | +                for many_to_many in model._meta.many_to_many:  | 
 | 36 | +                    through = many_to_many.remote_field.through  | 
 | 37 | +                    if through and through._meta.auto_created:  | 
 | 38 | +                        del new_apps.all_models["schema"][through._meta.model_name]  | 
 | 39 | +                del new_apps.all_models["schema"][model._meta.model_name]  | 
 | 40 | +        if self.isolated_local_models:  | 
 | 41 | +            with connection.schema_editor() as editor:  | 
 | 42 | +                for model in self.isolated_local_models:  | 
 | 43 | +                    editor.delete_model(model)  | 
 | 44 | + | 
 | 45 | +    def delete_tables(self):  | 
 | 46 | +        "Deletes all model tables for our models for a clean test environment"  | 
 | 47 | +        converter = connection.introspection.identifier_converter  | 
 | 48 | +        with connection.schema_editor() as editor:  | 
 | 49 | +            connection.disable_constraint_checking()  | 
 | 50 | +            table_names = connection.introspection.table_names()  | 
 | 51 | +            if connection.features.ignores_table_name_case:  | 
 | 52 | +                table_names = [table_name.lower() for table_name in table_names]  | 
 | 53 | +            for model in itertools.chain(SchemaTests.models, self.local_models):  | 
 | 54 | +                tbl = converter(model._meta.db_table)  | 
 | 55 | +                if connection.features.ignores_table_name_case:  | 
 | 56 | +                    tbl = tbl.lower()  | 
 | 57 | +                if tbl in table_names:  | 
 | 58 | +                    editor.delete_model(model)  | 
 | 59 | +                    table_names.remove(tbl)  | 
 | 60 | +            connection.enable_constraint_checking()  | 
 | 61 | + | 
 | 62 | +    def column_classes(self, model):  | 
 | 63 | +        with connection.cursor() as cursor:  | 
 | 64 | +            columns = {  | 
 | 65 | +                d[0]: (connection.introspection.get_field_type(d[1], d), d)  | 
 | 66 | +                for d in connection.introspection.get_table_description(  | 
 | 67 | +                    cursor,  | 
 | 68 | +                    model._meta.db_table,  | 
 | 69 | +                )  | 
 | 70 | +            }  | 
 | 71 | +        # SQLite has a different format for field_type  | 
 | 72 | +        for name, (type, desc) in columns.items():  | 
 | 73 | +            if isinstance(type, tuple):  | 
 | 74 | +                columns[name] = (type[0], desc)  | 
 | 75 | +        return columns  | 
 | 76 | + | 
 | 77 | +    def get_primary_key(self, table):  | 
 | 78 | +        with connection.cursor() as cursor:  | 
 | 79 | +            return connection.introspection.get_primary_key_column(cursor, table)  | 
 | 80 | + | 
 | 81 | +    def get_indexes(self, table):  | 
 | 82 | +        """  | 
 | 83 | +        Get the indexes on the table using a new cursor.  | 
 | 84 | +        """  | 
 | 85 | +        with connection.cursor() as cursor:  | 
 | 86 | +            return [  | 
 | 87 | +                c["columns"][0]  | 
 | 88 | +                for c in connection.introspection.get_constraints(cursor, table).values()  | 
 | 89 | +                if c["index"] and len(c["columns"]) == 1  | 
 | 90 | +            ]  | 
 | 91 | + | 
 | 92 | +    def get_uniques(self, table):  | 
 | 93 | +        with connection.cursor() as cursor:  | 
 | 94 | +            return [  | 
 | 95 | +                c["columns"][0]  | 
 | 96 | +                for c in connection.introspection.get_constraints(cursor, table).values()  | 
 | 97 | +                if c["unique"] and len(c["columns"]) == 1  | 
 | 98 | +            ]  | 
 | 99 | + | 
 | 100 | +    def get_constraints(self, table):  | 
 | 101 | +        """  | 
 | 102 | +        Get the constraints on a table using a new cursor.  | 
 | 103 | +        """  | 
 | 104 | +        with connection.cursor() as cursor:  | 
 | 105 | +            return connection.introspection.get_constraints(cursor, table)  | 
 | 106 | + | 
 | 107 | +    def get_constraints_for_column(self, model, column_name):  | 
 | 108 | +        constraints = self.get_constraints(model._meta.db_table)  | 
 | 109 | +        constraints_for_column = []  | 
 | 110 | +        for name, details in constraints.items():  | 
 | 111 | +            if details["columns"] == [column_name]:  | 
 | 112 | +                constraints_for_column.append(name)  | 
 | 113 | +        return sorted(constraints_for_column)  | 
 | 114 | + | 
 | 115 | +    def get_constraint_opclasses(self, constraint_name):  | 
 | 116 | +        with connection.cursor() as cursor:  | 
 | 117 | +            sql = """  | 
 | 118 | +                SELECT opcname  | 
 | 119 | +                FROM pg_opclass AS oc  | 
 | 120 | +                JOIN pg_index as i on oc.oid = ANY(i.indclass)  | 
 | 121 | +                JOIN pg_class as c on c.oid = i.indexrelid  | 
 | 122 | +                WHERE c.relname = %s  | 
 | 123 | +            """  | 
 | 124 | +            cursor.execute(sql, [constraint_name])  | 
 | 125 | +            return [row[0] for row in cursor.fetchall()]  | 
 | 126 | + | 
 | 127 | +    def check_added_field_default(  | 
 | 128 | +        self,  | 
 | 129 | +        schema_editor,  | 
 | 130 | +        model,  | 
 | 131 | +        field,  | 
 | 132 | +        field_name,  | 
 | 133 | +        expected_default,  | 
 | 134 | +        cast_function=None,  | 
 | 135 | +    ):  | 
 | 136 | +        schema_editor.add_field(model, field)  | 
 | 137 | +        database_default = connection.database[model._meta.db_table].find_one().get(field_name)  | 
 | 138 | +        # cursor.execute(  | 
 | 139 | +        #     "SELECT {} FROM {};".format(field_name, model._meta.db_table)  | 
 | 140 | +        # )  | 
 | 141 | +        # database_default = cursor.fetchall()[0][0]  | 
 | 142 | +        if cast_function and type(database_default) is not type(expected_default):  | 
 | 143 | +            database_default = cast_function(database_default)  | 
 | 144 | +        self.assertEqual(database_default, expected_default)  | 
 | 145 | + | 
 | 146 | +    def get_constraints_count(self, table, column, fk_to):  | 
 | 147 | +        """  | 
 | 148 | +        Return a dict with keys 'fks', 'uniques, and 'indexes' indicating the  | 
 | 149 | +        number of foreign keys, unique constraints, and indexes on  | 
 | 150 | +        `table`.`column`. The `fk_to` argument is a 2-tuple specifying the  | 
 | 151 | +        expected foreign key relationship's (table, column).  | 
 | 152 | +        """  | 
 | 153 | +        with connection.cursor() as cursor:  | 
 | 154 | +            constraints = connection.introspection.get_constraints(cursor, table)  | 
 | 155 | +        counts = {"fks": 0, "uniques": 0, "indexes": 0}  | 
 | 156 | +        for c in constraints.values():  | 
 | 157 | +            if c["columns"] == [column]:  | 
 | 158 | +                if c["foreign_key"] == fk_to:  | 
 | 159 | +                    counts["fks"] += 1  | 
 | 160 | +                if c["unique"]:  | 
 | 161 | +                    counts["uniques"] += 1  | 
 | 162 | +                elif c["index"]:  | 
 | 163 | +                    counts["indexes"] += 1  | 
 | 164 | +        return counts  | 
 | 165 | + | 
 | 166 | +    def get_column_collation(self, table, column):  | 
 | 167 | +        with connection.cursor() as cursor:  | 
 | 168 | +            return next(  | 
 | 169 | +                f.collation  | 
 | 170 | +                for f in connection.introspection.get_table_description(cursor, table)  | 
 | 171 | +                if f.name == column  | 
 | 172 | +            )  | 
 | 173 | + | 
 | 174 | +    def get_column_comment(self, table, column):  | 
 | 175 | +        with connection.cursor() as cursor:  | 
 | 176 | +            return next(  | 
 | 177 | +                f.comment  | 
 | 178 | +                for f in connection.introspection.get_table_description(cursor, table)  | 
 | 179 | +                if f.name == column  | 
 | 180 | +            )  | 
 | 181 | + | 
 | 182 | +    def get_table_comment(self, table):  | 
 | 183 | +        with connection.cursor() as cursor:  | 
 | 184 | +            return next(  | 
 | 185 | +                t.comment  | 
 | 186 | +                for t in connection.introspection.get_table_list(cursor)  | 
 | 187 | +                if t.name == table  | 
 | 188 | +            )  | 
 | 189 | + | 
 | 190 | +    def assert_column_comment_not_exists(self, table, column):  | 
 | 191 | +        with connection.cursor() as cursor:  | 
 | 192 | +            columns = connection.introspection.get_table_description(cursor, table)  | 
 | 193 | +        self.assertFalse(any(c.name == column and c.comment for c in columns))  | 
 | 194 | + | 
 | 195 | +    def assertIndexOrder(self, table, index, order):  | 
 | 196 | +        constraints = self.get_constraints(table)  | 
 | 197 | +        self.assertIn(index, constraints)  | 
 | 198 | +        index_orders = constraints[index]["orders"]  | 
 | 199 | +        self.assertTrue(  | 
 | 200 | +            all(val == expected for val, expected in zip(index_orders, order, strict=True))  | 
 | 201 | +        )  | 
 | 202 | + | 
 | 203 | +    def assertForeignKeyExists(self, model, column, expected_fk_table, field="id"):  | 
 | 204 | +        """  | 
 | 205 | +        Fail if the FK constraint on `model.Meta.db_table`.`column` to  | 
 | 206 | +        `expected_fk_table`.id doesn't exist.  | 
 | 207 | +        """  | 
 | 208 | +        if not connection.features.can_introspect_foreign_keys:  | 
 | 209 | +            return  | 
 | 210 | +        constraints = self.get_constraints(model._meta.db_table)  | 
 | 211 | +        constraint_fk = None  | 
 | 212 | +        for details in constraints.values():  | 
 | 213 | +            if details["columns"] == [column] and details["foreign_key"]:  | 
 | 214 | +                constraint_fk = details["foreign_key"]  | 
 | 215 | +                break  | 
 | 216 | +        self.assertEqual(constraint_fk, (expected_fk_table, field))  | 
 | 217 | + | 
 | 218 | +    def assertForeignKeyNotExists(self, model, column, expected_fk_table):  | 
 | 219 | +        if not connection.features.can_introspect_foreign_keys:  | 
 | 220 | +            return  | 
 | 221 | +        with self.assertRaises(AssertionError):  | 
 | 222 | +            self.assertForeignKeyExists(model, column, expected_fk_table)  | 
 | 223 | + | 
 | 224 | +    def assertTableExists(self, model):  | 
 | 225 | +        self.assertIn(model._meta.db_table, connection.introspection.table_names())  | 
 | 226 | + | 
 | 227 | +    def assertTableNotExists(self, model):  | 
 | 228 | +        self.assertNotIn(model._meta.db_table, connection.introspection.table_names())  | 
 | 229 | + | 
 | 230 | +    # Tests  | 
 | 231 | +    def test_embedded_index(self):  | 
 | 232 | +        """db_index on an embedded model."""  | 
 | 233 | +        with connection.schema_editor() as editor:  | 
 | 234 | +            # Create the table  | 
 | 235 | +            editor.create_model(Book)  | 
 | 236 | +            # The table is there  | 
 | 237 | +            self.assertTableExists(Book)  | 
 | 238 | +            # Embedded indexes are created.  | 
 | 239 | +            self.assertEqual(  | 
 | 240 | +                self.get_constraints_for_column(Book, "author.age"),  | 
 | 241 | +                ["schema__book_author.age_dc08100b"],  | 
 | 242 | +            )  | 
 | 243 | +            self.assertEqual(  | 
 | 244 | +                self.get_constraints_for_column(Book, "author.address.zip_code"),  | 
 | 245 | +                ["schema__book_author.address.zip_code_7b9a9307"],  | 
 | 246 | +            )  | 
 | 247 | +            # Clean up that table  | 
 | 248 | +            editor.delete_model(Author)  | 
 | 249 | +            # Indexes are gone.  | 
 | 250 | +            self.assertEqual(  | 
 | 251 | +                self.get_constraints_for_column(Author, "author.address.zip_code"),  | 
 | 252 | +                [],  | 
 | 253 | +            )  | 
 | 254 | +        # The table is gone  | 
 | 255 | +        self.assertTableNotExists(Author)  | 
0 commit comments