# mixins.py
from django.db.models import F, OrderBy


class OrderableAggMixin:
    """Mixin that adds an ``ordering`` argument to an aggregate expression.

    The ordering expressions are kept in ``self.ordering``. They are compiled
    into an ``ORDER BY ...`` fragment that is handed to the parent's
    ``as_sql()`` through the ``ordering`` keyword, and they are appended after
    the parent's own expressions in ``get_source_expressions()``.

    The next class in the MRO is expected to provide ``_parse_expressions()``,
    ``resolve_expression()``, ``as_sql()``, ``get_source_expressions()`` and
    ``set_source_expressions()``.
    """

    def __init__(self, *expressions, ordering=(), **extra):
        """Initialise the parent and store the parsed ordering expressions.

        ``ordering`` may be a single item or a list/tuple of items. A string
        prefixed with "-" becomes a descending ``OrderBy`` on the remainder of
        the string; every other item is handed to ``_parse_expressions()``
        unchanged.
        """
        if not isinstance(ordering, (list, tuple)):
            ordering = [ordering]
        super().__init__(*expressions, **extra)
        # Transform minus sign prefixed strings into an OrderBy() expression.
        # str.startswith() is used rather than o[0] so that an empty string
        # does not raise IndexError here.
        ordering = [
            (
                OrderBy(F(o[1:]), descending=True)
                if isinstance(o, str) and o.startswith("-")
                else o
            )
            for o in ordering
        ]
        self.ordering = self._parse_expressions(*ordering)

    def resolve_expression(self, *args, **kwargs):
        """Resolve the ordering expressions, then delegate to the parent.

        Every ordering expression is resolved with the same arguments as the
        aggregate itself. The resolved list replaces ``self.ordering`` on this
        instance before the parent's ``resolve_expression()`` is called, and
        the parent's result is returned.
        """
        self.ordering = [
            expr.resolve_expression(*args, **kwargs) for expr in self.ordering
        ]
        return super().resolve_expression(*args, **kwargs)

    def as_sql(self, compiler, connection):
        """Return ``(sql, params)`` with the ORDER BY fragment included.

        The fragment is passed to the parent's ``as_sql()`` as the
        ``ordering`` keyword (an empty string when there is no ordering). The
        parameters of the ordering expressions are appended after the
        parent's parameters.
        """
        if not self.ordering:
            return super().as_sql(compiler, connection, ordering="")

        ordering_expr_sql = []
        ordering_params = []
        for expr in self.ordering:
            expr_sql, expr_params = compiler.compile(expr)
            ordering_expr_sql.append(expr_sql)
            ordering_params.extend(expr_params)
        sql, sql_params = super().as_sql(
            compiler,
            connection,
            ordering="ORDER BY " + ", ".join(ordering_expr_sql),
        )
        # The parent may return its params as a list or as a tuple; unpacking
        # both avoids the TypeError that ``tuple + list`` would raise.
        return sql, [*sql_params, *ordering_params]

    def set_source_expressions(self, exprs):
        """Split ``exprs`` into the parent's expressions and the ordering."""
        # Extract the ordering expressions because ORDER BY clause is handled
        # in a custom way. The split point is computed once, before
        # self.ordering is replaced.
        index = self._get_ordering_expressions_index()
        self.ordering = exprs[index:]
        return super().set_source_expressions(exprs[:index])

    def get_source_expressions(self):
        """Return the parent's source expressions followed by the ordering."""
        return super().get_source_expressions() + self.ordering

    def _get_ordering_expressions_index(self):
        """Return the index at which the ordering expressions start."""
        source_expressions = self.get_source_expressions()
        return len(source_expressions) - len(self.ordering)