from enum import Enum from django.db.models.expressions import ExpressionList, F from django.db.models.indexes import IndexExpression from django.db.models.query_utils import Q from django.db.models.sql.query import Query __all__ = ["CheckConstraint", "Deferrable", "UniqueConstraint"] class BaseConstraint: def __init__(self, name): self.name = name @property def contains_expressions(self): return False def constraint_sql(self, model, schema_editor): raise NotImplementedError("This method must be implemented by a subclass.") def create_sql(self, model, schema_editor): raise NotImplementedError("This method must be implemented by a subclass.") def remove_sql(self, model, schema_editor): raise NotImplementedError("This method must be implemented by a subclass.") def deconstruct(self): path = "%s.%s" % (self.__class__.__module__, self.__class__.__name__) path = path.replace("django.db.models.constraints", "django.db.models") return (path, (), {"name": self.name}) def clone(self): _, args, kwargs = self.deconstruct() return self.__class__(*args, **kwargs) class CheckConstraint(BaseConstraint): def __init__(self, *, check, name): self.check = check if not getattr(check, "conditional", False): raise TypeError( "CheckConstraint.check must be a Q instance or boolean expression." ) super().__init__(name) def _get_check_sql(self, model, schema_editor): query = Query(model=model, alias_cols=False) where = query.build_where(self.check) compiler = query.get_compiler(connection=schema_editor.connection) sql, params = where.as_sql(compiler, schema_editor.connection) return sql % tuple(schema_editor.quote_value(p) for p in params) def constraint_sql(self, model, schema_editor): check = self._get_check_sql(model, schema_editor) return schema_editor._check_sql(self.name, check) def create_sql(self, model, schema_editor): check = self._get_check_sql(model, schema_editor) return schema_editor._create_check_sql(model, self.name, check) def remove_sql(self, model, schema_editor): return schema_editor._delete_check_sql(model, self.name) def __repr__(self): return "<%s: check=%s name=%s>" % ( self.__class__.__qualname__, self.check, repr(self.name), ) def __eq__(self, other): if isinstance(other, CheckConstraint): return self.name == other.name and self.check == other.check return super().__eq__(other) def deconstruct(self): path, args, kwargs = super().deconstruct() kwargs["check"] = self.check return path, args, kwargs class Deferrable(Enum): DEFERRED = "deferred" IMMEDIATE = "immediate" # A similar format was proposed for Python 3.10. def __repr__(self): return f"{self.__class__.__qualname__}.{self._name_}" class UniqueConstraint(BaseConstraint): def __init__( self, *expressions, fields=(), name=None, condition=None, deferrable=None, include=None, opclasses=(), ): if not name: raise ValueError("A unique constraint must be named.") if not expressions and not fields: raise ValueError( "At least one field or expression is required to define a " "unique constraint." ) if expressions and fields: raise ValueError( "UniqueConstraint.fields and expressions are mutually exclusive." ) if not isinstance(condition, (type(None), Q)): raise ValueError("UniqueConstraint.condition must be a Q instance.") if condition and deferrable: raise ValueError("UniqueConstraint with conditions cannot be deferred.") if include and deferrable: raise ValueError("UniqueConstraint with include fields cannot be deferred.") if opclasses and deferrable: raise ValueError("UniqueConstraint with opclasses cannot be deferred.") if expressions and deferrable: raise ValueError("UniqueConstraint with expressions cannot be deferred.") if expressions and opclasses: raise ValueError( "UniqueConstraint.opclasses cannot be used with expressions. " "Use django.contrib.postgres.indexes.OpClass() instead." ) if not isinstance(deferrable, (type(None), Deferrable)): raise ValueError( "UniqueConstraint.deferrable must be a Deferrable instance." ) if not isinstance(include, (type(None), list, tuple)): raise ValueError("UniqueConstraint.include must be a list or tuple.") if not isinstance(opclasses, (list, tuple)): raise ValueError("UniqueConstraint.opclasses must be a list or tuple.") if opclasses and len(fields) != len(opclasses): raise ValueError( "UniqueConstraint.fields and UniqueConstraint.opclasses must " "have the same number of elements." ) self.fields = tuple(fields) self.condition = condition self.deferrable = deferrable self.include = tuple(include) if include else () self.opclasses = opclasses self.expressions = tuple( F(expression) if isinstance(expression, str) else expression for expression in expressions ) super().__init__(name) @property def contains_expressions(self): return bool(self.expressions) def _get_condition_sql(self, model, schema_editor): if self.condition is None: return None query = Query(model=model, alias_cols=False) where = query.build_where(self.condition) compiler = query.get_compiler(connection=schema_editor.connection) sql, params = where.as_sql(compiler, schema_editor.connection) return sql % tuple(schema_editor.quote_value(p) for p in params) def _get_index_expressions(self, model, schema_editor): if not self.expressions: return None index_expressions = [] for expression in self.expressions: index_expression = IndexExpression(expression) index_expression.set_wrapper_classes(schema_editor.connection) index_expressions.append(index_expression) return ExpressionList(*index_expressions).resolve_expression( Query(model, alias_cols=False), ) def constraint_sql(self, model, schema_editor): fields = [model._meta.get_field(field_name) for field_name in self.fields] include = [ model._meta.get_field(field_name).column for field_name in self.include ] condition = self._get_condition_sql(model, schema_editor) expressions = self._get_index_expressions(model, schema_editor) return schema_editor._unique_sql( model, fields, self.name, condition=condition, deferrable=self.deferrable, include=include, opclasses=self.opclasses, expressions=expressions, ) def create_sql(self, model, schema_editor): fields = [model._meta.get_field(field_name) for field_name in self.fields] include = [ model._meta.get_field(field_name).column for field_name in self.include ] condition = self._get_condition_sql(model, schema_editor) expressions = self._get_index_expressions(model, schema_editor) return schema_editor._create_unique_sql( model, fields, self.name, condition=condition, deferrable=self.deferrable, include=include, opclasses=self.opclasses, expressions=expressions, ) def remove_sql(self, model, schema_editor): condition = self._get_condition_sql(model, schema_editor) include = [ model._meta.get_field(field_name).column for field_name in self.include ] expressions = self._get_index_expressions(model, schema_editor) return schema_editor._delete_unique_sql( model, self.name, condition=condition, deferrable=self.deferrable, include=include, opclasses=self.opclasses, expressions=expressions, ) def __repr__(self): return "<%s:%s%s%s%s%s%s%s>" % ( self.__class__.__qualname__, "" if not self.fields else " fields=%s" % repr(self.fields), "" if not self.expressions else " expressions=%s" % repr(self.expressions), " name=%s" % repr(self.name), "" if self.condition is None else " condition=%s" % self.condition, "" if self.deferrable is None else " deferrable=%r" % self.deferrable, "" if not self.include else " include=%s" % repr(self.include), "" if not self.opclasses else " opclasses=%s" % repr(self.opclasses), ) def __eq__(self, other): if isinstance(other, UniqueConstraint): return ( self.name == other.name and self.fields == other.fields and self.condition == other.condition and self.deferrable == other.deferrable and self.include == other.include and self.opclasses == other.opclasses and self.expressions == other.expressions ) return super().__eq__(other) def deconstruct(self): path, args, kwargs = super().deconstruct() if self.fields: kwargs["fields"] = self.fields if self.condition: kwargs["condition"] = self.condition if self.deferrable: kwargs["deferrable"] = self.deferrable if self.include: kwargs["include"] = self.include if self.opclasses: kwargs["opclasses"] = self.opclasses return path, self.expressions, kwargs