indexes.py 11.6 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295
from django.db.backends.utils import names_digest, split_identifier
from django.db.models.expressions import Col, ExpressionList, F, Func, OrderBy
from django.db.models.functions import Collate
from django.db.models.query_utils import Q
from django.db.models.sql import Query
from django.utils.functional import partition

# Names exported by ``from django.db.models.indexes import *``; IndexExpression
# is not included.
__all__ = ["Index"]


class Index:
    """
    Describe a database index over model fields or expressions.

    An index is defined either by ``fields`` (field names, optionally prefixed
    with "-" for descending order) or by positional ``expressions``, never
    both. Indexes that use expressions, ``opclasses``, ``condition`` or
    ``include`` must be given an explicit ``name``.
    """

    suffix = "idx"
    # The max length of the name of the index (restricted to 30 for
    # cross-database compatibility with Oracle)
    max_name_length = 30

    def __init__(
        self,
        *expressions,
        fields=(),
        name=None,
        db_tablespace=None,
        opclasses=(),
        condition=None,
        include=None,
    ):
        """
        Validate and store the index definition.

        ``expressions`` are positional expressions (plain strings become F()
        references). ``fields`` are field names; a leading "-" requests
        descending order. ``opclasses`` must have one entry per field,
        ``condition`` must be a Q instance, and ``include`` names the fields
        of a covering index.

        Raise ValueError if the options are missing, mistyped or mutually
        inconsistent.
        """
        if opclasses and not name:
            raise ValueError("An index must be named to use opclasses.")
        if not isinstance(condition, (type(None), Q)):
            raise ValueError("Index.condition must be a Q instance.")
        if condition and not name:
            raise ValueError("An index must be named to use condition.")
        if not isinstance(fields, (list, tuple)):
            raise ValueError("Index.fields must be a list or tuple.")
        if not isinstance(opclasses, (list, tuple)):
            raise ValueError("Index.opclasses must be a list or tuple.")
        if not expressions and not fields:
            raise ValueError(
                "At least one field or expression is required to define an index."
            )
        if expressions and fields:
            raise ValueError(
                "Index.fields and expressions are mutually exclusive.",
            )
        if expressions and not name:
            raise ValueError("An index must be named to use expressions.")
        if expressions and opclasses:
            raise ValueError(
                "Index.opclasses cannot be used with expressions. Use "
                "django.contrib.postgres.indexes.OpClass() instead."
            )
        if opclasses and len(fields) != len(opclasses):
            raise ValueError(
                "Index.fields and Index.opclasses must have the same number of "
                "elements."
            )
        if fields and not all(isinstance(field, str) for field in fields):
            raise ValueError("Index.fields must contain only strings with field names.")
        if include and not name:
            raise ValueError("A covering index must be named.")
        if not isinstance(include, (type(None), list, tuple)):
            raise ValueError("Index.include must be a list or tuple.")
        self.fields = list(fields)
        # A list of 2-tuple with the field name and ordering ('' or 'DESC').
        self.fields_orders = [
            (field_name[1:], "DESC") if field_name.startswith("-") else (field_name, "")
            for field_name in self.fields
        ]
        self.name = name or ""
        self.db_tablespace = db_tablespace
        self.opclasses = opclasses
        self.condition = condition
        self.include = tuple(include) if include else ()
        self.expressions = tuple(
            F(expression) if isinstance(expression, str) else expression
            for expression in expressions
        )

    @property
    def contains_expressions(self):
        """Return True if the index is defined by expressions rather than fields."""
        return bool(self.expressions)

    def _get_condition_sql(self, model, schema_editor):
        """
        Return the partial-index condition as SQL with its parameters inlined
        via ``schema_editor.quote_value()``, or None when there is no
        condition.
        """
        # An empty Q() is falsy. __init__ (name check) and deconstruct() already
        # treat it as "no condition"; do the same here rather than compiling an
        # empty WHERE clause that deconstruct()/clone() would not preserve.
        if not self.condition:
            return None
        query = Query(model=model, alias_cols=False)
        where = query.build_where(self.condition)
        compiler = query.get_compiler(connection=schema_editor.connection)
        sql, params = where.as_sql(compiler, schema_editor.connection)
        return sql % tuple(schema_editor.quote_value(p) for p in params)

    def create_sql(self, model, schema_editor, using="", **kwargs):
        """
        Return the index creation statement built by
        ``schema_editor._create_index_sql()``.

        Field-based indexes pass the model fields plus per-column ordering
        suffixes (blank when the backend lacks
        ``supports_index_column_ordering``); expression-based indexes pass the
        resolved, wrapped expressions instead. Extra ``kwargs`` are forwarded.
        """
        include = [
            model._meta.get_field(field_name).column for field_name in self.include
        ]
        condition = self._get_condition_sql(model, schema_editor)
        if self.expressions:
            index_expressions = []
            for expression in self.expressions:
                index_expression = IndexExpression(expression)
                index_expression.set_wrapper_classes(schema_editor.connection)
                index_expressions.append(index_expression)
            expressions = ExpressionList(*index_expressions).resolve_expression(
                Query(model, alias_cols=False),
            )
            fields = None
            col_suffixes = None
        else:
            fields = [
                model._meta.get_field(field_name)
                for field_name, _ in self.fields_orders
            ]
            if schema_editor.connection.features.supports_index_column_ordering:
                col_suffixes = [order[1] for order in self.fields_orders]
            else:
                col_suffixes = [""] * len(self.fields_orders)
            expressions = None
        return schema_editor._create_index_sql(
            model,
            fields=fields,
            name=self.name,
            using=using,
            db_tablespace=self.db_tablespace,
            col_suffixes=col_suffixes,
            opclasses=self.opclasses,
            condition=condition,
            include=include,
            expressions=expressions,
            **kwargs,
        )

    def remove_sql(self, model, schema_editor, **kwargs):
        """Return the statement from ``schema_editor._delete_index_sql()`` for this index."""
        return schema_editor._delete_index_sql(model, self.name, **kwargs)

    def deconstruct(self):
        """
        Return ``(path, args, kwargs)`` from which this index can be rebuilt.

        The import path is shortened to ``django.db.models``. ``name`` is
        always included; the other options appear only when set.
        """
        path = "%s.%s" % (self.__class__.__module__, self.__class__.__name__)
        path = path.replace("django.db.models.indexes", "django.db.models")
        kwargs = {"name": self.name}
        if self.fields:
            kwargs["fields"] = self.fields
        if self.db_tablespace is not None:
            kwargs["db_tablespace"] = self.db_tablespace
        if self.opclasses:
            kwargs["opclasses"] = self.opclasses
        if self.condition:
            kwargs["condition"] = self.condition
        if self.include:
            kwargs["include"] = self.include
        return (path, self.expressions, kwargs)

    def clone(self):
        """Create a copy of this Index."""
        _, args, kwargs = self.deconstruct()
        return self.__class__(*args, **kwargs)

    def set_name_with_model(self, model):
        """
        Generate a unique name for the index.

        The name has 3 parts:
        - the table name, truncated to 11 chars, plus "_";
        - the first column name, truncated to 7 chars, plus "_";
        - a 6-char digest followed by "_" and the suffix. The digest covers
          the table, the ordered columns and the suffix.

        With a 3-char suffix this fills the 30-char ``max_name_length``. A
        leading "_" or digit is replaced with "D".

        Raise ValueError if the generated name exceeds ``max_name_length``.
        """
        _, table_name = split_identifier(model._meta.db_table)
        column_names = [
            model._meta.get_field(field_name).column
            for field_name, _ in self.fields_orders
        ]
        # Include the ordering in the digest so that e.g. "col" and "-col"
        # indexes on the same table get different names.
        column_names_with_order = [
            (("-%s" if order else "%s") % column_name)
            for column_name, (_, order) in zip(column_names, self.fields_orders)
        ]
        # The length of the parts of the name is based on the default max
        # length of 30 characters.
        hash_data = [table_name] + column_names_with_order + [self.suffix]
        self.name = "%s_%s_%s" % (
            table_name[:11],
            column_names[0][:7],
            "%s_%s" % (names_digest(*hash_data, length=6), self.suffix),
        )
        if len(self.name) > self.max_name_length:
            raise ValueError(
                "Index too long for multiple database support. Is self.suffix "
                "longer than 3 characters?"
            )
        # Replace a leading underscore or digit with "D".
        if self.name[0] == "_" or self.name[0].isdigit():
            self.name = "D%s" % self.name[1:]

    def __repr__(self):
        return "<%s:%s%s%s%s%s%s%s>" % (
            self.__class__.__qualname__,
            "" if not self.fields else " fields=%s" % repr(self.fields),
            "" if not self.expressions else " expressions=%s" % repr(self.expressions),
            "" if not self.name else " name=%s" % repr(self.name),
            ""
            if self.db_tablespace is None
            else " db_tablespace=%s" % repr(self.db_tablespace),
            "" if self.condition is None else " condition=%s" % self.condition,
            "" if not self.include else " include=%s" % repr(self.include),
            "" if not self.opclasses else " opclasses=%s" % repr(self.opclasses),
        )

    def __eq__(self, other):
        """Indexes of the same class are equal when their deconstructions match."""
        # Defining __eq__ without __hash__ leaves instances unhashable.
        if self.__class__ == other.__class__:
            return self.deconstruct() == other.deconstruct()
        return NotImplemented


class IndexExpression(Func):
    """
    Wrap one index expression for use in CREATE INDEX statements.

    Wrapper nodes (``wrapper_classes``, by default OrderBy and Collate) must
    sit at the top of the expression tree. On resolution they are re-nested
    in ``wrapper_classes`` order around the root expression. The root is
    parenthesized unless it resolves to a plain column reference.
    """

    template = "%(expressions)s"
    wrapper_classes = (OrderBy, Collate)

    def set_wrapper_classes(self, connection=None):
        """
        Drop Collate from this instance's wrappers when the connection's
        features report ``collate_as_index_expression``.
        """
        # Some databases (e.g. MySQL) treat COLLATE as part of the indexed
        # expression rather than as a wrapper around it.
        if not connection:
            return
        if not connection.features.collate_as_index_expression:
            return
        self.wrapper_classes = tuple(
            klass for klass in self.wrapper_classes if klass is not Collate
        )

    @classmethod
    def register_wrappers(cls, *wrapper_classes):
        """Set the class-level wrapper classes, outermost first."""
        cls.wrapper_classes = wrapper_classes

    def resolve_expression(
        self,
        query=None,
        allow_joins=True,
        reuse=None,
        summarize=False,
        for_save=False,
    ):
        """
        Rebuild this node's source as ``wrapper(...wrapper(root))`` and resolve it.

        Raise ValueError if a wrapper type occurs more than once, or if the
        wrappers are not the topmost nodes of the expression tree.
        """

        def describe_wrappers():
            return ", ".join(
                [klass.__qualname__ for klass in self.wrapper_classes]
            )

        nodes = list(self.flatten())
        # partition() yields (non-matching, matching) items.
        plain_nodes, wrapper_nodes = partition(
            lambda node: isinstance(node, self.wrapper_classes), nodes
        )
        wrapper_kinds = [type(node) for node in wrapper_nodes]
        if len(set(wrapper_kinds)) != len(wrapper_kinds):
            raise ValueError(
                "Multiple references to %s can't be used in an indexed "
                "expression." % describe_wrappers()
            )
        # Position 0 of the flattened tree is skipped (expected to be this
        # node itself); the wrappers must come directly after it.
        if nodes[1 : len(wrapper_nodes) + 1] != wrapper_nodes:
            raise ValueError(
                "%s must be topmost expressions in an indexed expression."
                % describe_wrappers()
            )

        # Parenthesize the root unless it resolves to a bare column.
        root = plain_nodes[1]
        resolved_root = root.resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )
        if not isinstance(resolved_root, Col):
            root = Func(root, template="(%(expressions)s)")

        if not wrapper_nodes:
            # Nothing to wrap: the root becomes the sole source expression.
            self.set_source_expressions([root])
        else:
            rank = self.wrapper_classes.index
            chain = [
                node.copy()
                for node in sorted(wrapper_nodes, key=lambda node: rank(type(node)))
            ]
            # Nest each wrapper inside the previous one; the innermost wraps
            # the root expression.
            for outer, inner in zip(chain, chain[1:]):
                outer.set_source_expressions([inner])
            chain[-1].set_source_expressions([root])
            self.set_source_expressions([chain[0]])
        return super().resolve_expression(
            query, allow_joins, reuse, summarize, for_save
        )

    def as_sqlite(self, compiler, connection, **extra_context):
        # Override the inherited SQLite rendering: casting to numeric is
        # unnecessary for index expressions.
        return self.as_sql(compiler, connection, **extra_context)