import re
from collections import namedtuple

import sqlparse

from django.db.backends.base.introspection import BaseDatabaseIntrospection
from django.db.backends.base.introspection import FieldInfo as BaseFieldInfo
from django.db.backends.base.introspection import TableInfo
from django.db.models import Index
from django.utils.regex_helper import _lazy_re_compile

# Extend the base FieldInfo tuple with two trailing, SQLite-specific fields:
# "pk" and "has_json_constraint".
FieldInfo = namedtuple(
    "FieldInfo", BaseFieldInfo._fields + ("pk", "has_json_constraint")
)

# Matches a whole "char(N)" or "varchar(N)" type name (surrounding whitespace
# allowed) and captures N. No case-insensitivity flag is passed here, so the
# pattern as written only matches lowercase type names.
field_size_re = _lazy_re_compile(r"^\s*(?:var)?char\s*\(\s*(\d+)\s*\)\s*$")


def get_field_size(name):
    """
    Return the declared size of a "char(N)"/"varchar(N)" type name as an int,
    or None when the name doesn't have that form.
    """
    match = field_size_re.search(name)
    if match is None:
        return None
    return int(match.group(1))


# This light wrapper "fakes" a dictionary interface, because some SQLite data
# types include variables in them -- e.g. "varchar(30)" -- and can't be matched
# as a simple dictionary lookup.
class FlexibleFieldLookupDict:
    """Read-only, dict-like lookup from SQLite type names to Django fields."""

    # Maps SQL types to Django Field types. SQLite stores whatever type name
    # was given without normalizing it, so several spellings of the same type
    # are listed.
    base_data_types_reverse = {
        "bool": "BooleanField",
        "boolean": "BooleanField",
        "smallint": "SmallIntegerField",
        "smallint unsigned": "PositiveSmallIntegerField",
        "smallinteger": "SmallIntegerField",
        "int": "IntegerField",
        "integer": "IntegerField",
        "bigint": "BigIntegerField",
        "integer unsigned": "PositiveIntegerField",
        "bigint unsigned": "PositiveBigIntegerField",
        "decimal": "DecimalField",
        "real": "FloatField",
        "text": "TextField",
        "char": "CharField",
        "varchar": "CharField",
        "blob": "BinaryField",
        "date": "DateField",
        "datetime": "DateTimeField",
        "time": "TimeField",
    }

    def __getitem__(self, key):
        """
        Return the Django field name for the SQL type ``key``.

        The lookup ignores case, surrounding whitespace and anything from the
        first "(" onwards. Raise KeyError for unknown types.
        """
        type_name, _, _ = key.lower().partition("(")
        return self.base_data_types_reverse[type_name.strip()]


class DatabaseIntrospection(BaseDatabaseIntrospection):
    data_types_reverse = FlexibleFieldLookupDict()

    def get_field_type(self, data_type, description):
        """
        Return the Django field name for a column, starting from the base
        class's answer and then reporting integer primary keys as AutoField
        and JSON-constrained columns as JSONField.
        """
        resolved = super().get_field_type(data_type, description)
        integer_fields = ("BigIntegerField", "IntegerField", "SmallIntegerField")
        # No support for BigAutoField or SmallAutoField as SQLite treats
        # all integer primary keys as signed 64-bit integers.
        if description.pk and resolved in integer_fields:
            return "AutoField"
        return "JSONField" if description.has_json_constraint else resolved

    def get_table_list(self, cursor):
        """Return a list of table and view names in the current database."""
        # The sqlite_sequence system table, used for autoincrement key
        # generation, is excluded by the query.
        cursor.execute(
            """
            SELECT name, type FROM sqlite_master
            WHERE type in ('table', 'view') AND NOT name='sqlite_sequence'
            ORDER BY name"""
        )
        tables = []
        for name, kind in cursor.fetchall():
            # Only the first letter of the type ("table"/"view") is kept.
            tables.append(TableInfo(name, kind[0]))
        return tables

    def get_table_description(self, cursor, table_name):
        """
        Return a description of the table with the DB-API cursor.description
        interface.

        Each entry is a FieldInfo that also carries the column's collation,
        whether it is the primary key, and whether the table's SQL contains a
        json_valid() constraint naming it.
        """
        quoted_table = self.connection.ops.quote_name(table_name)
        cursor.execute("PRAGMA table_info(%s)" % quoted_table)
        rows = cursor.fetchall()
        collations = self._get_column_collations(cursor, table_name)
        json_columns = set()
        if self.connection.features.can_introspect_json_field:
            for row in rows:
                column_name = row[1]
                # LIKE pattern looking for json_valid("<column>") anywhere in
                # the table's CREATE statement.
                like_pattern = '%%json_valid("%s")%%' % column_name
                matched = cursor.execute(
                    """
                    SELECT sql
                    FROM sqlite_master
                    WHERE
                        type = 'table' AND
                        name = %s AND
                        sql LIKE %s
                """,
                    [table_name, like_pattern],
                ).fetchone()
                if matched:
                    json_columns.add(column_name)
        description = []
        for _cid, name, data_type, notnull, default, pk in rows:
            description.append(
                FieldInfo(
                    name,
                    data_type,
                    None,
                    get_field_size(data_type),
                    None,
                    None,
                    not notnull,
                    default,
                    collations.get(name),
                    pk == 1,
                    name in json_columns,
                )
            )
        return description

    def get_sequences(self, cursor, table_name, table_fields=()):
        """
        Return a one-element list describing the table's sequence: the table
        name paired with its primary key column. ``table_fields`` is unused.
        """
        sequence = {
            "table": table_name,
            "column": self.get_primary_key_column(cursor, table_name),
        }
        return [sequence]

    def get_relations(self, cursor, table_name):
        """
        Return a dictionary of {field_name: (field_name_other_table, other_table)}
        representing all relationships to the given table.
        """
        # Fetch the CREATE statement for this table (or view).
        cursor.execute(
            "SELECT sql, type FROM sqlite_master "
            "WHERE tbl_name = %s AND type IN ('table', 'view')",
            [table_name],
        )
        schema, kind = cursor.fetchone()
        relations = {}
        if kind == "view":
            # A view has no column definitions to scan for references.
            return relations
        body = schema[schema.index("(") + 1 : schema.rindex(")")]

        # SQLite doesn't really have enforced references, but it echoes out
        # the SQL used to create the table, so REFERENCES clauses can be
        # found by scanning each comma-separated definition.
        for definition in body.split(","):
            definition = definition.strip()
            if definition.startswith("UNIQUE"):
                continue

            ref = re.search(
                r'references (\S*) ?\(["|]?(.*)["|]?\)', definition, re.I
            )
            if ref is None:
                continue
            table, column = (part.strip('"') for part in ref.groups())

            if definition.startswith("FOREIGN KEY"):
                # Table-level constraint: the local column is inside the
                # parentheses following FOREIGN KEY.
                fk_match = re.match(
                    r"FOREIGN KEY\s*\(([^\)]*)\).*", definition, re.I
                )
                field_name = fk_match[1].strip('"')
            else:
                # Column-level constraint: the local column is the first word.
                field_name = definition.split()[0].strip('"')

            # Look the referenced column up in the other table's definition.
            cursor.execute("SELECT sql FROM sqlite_master WHERE tbl_name = %s", [table])
            first_row = cursor.fetchall()[0]
            other_schema = first_row[0].strip()
            start = other_schema.index("(")
            end = other_schema.rindex(")")

            for other_definition in other_schema[start + 1 : end].split(","):
                other_definition = other_definition.strip()
                if other_definition.startswith("UNIQUE"):
                    continue

                other_name = other_definition.split(" ", 1)[0].strip('"')
                if other_name == column:
                    relations[field_name] = (other_name, table)
                    break

        return relations

    def get_key_columns(self, cursor, table_name):
        """
        Return a list of (column_name, referenced_table_name, referenced_column_name)
        for all key columns in given table.
        """
        key_columns = []

        # Schema for this table
        cursor.execute(
            "SELECT sql FROM sqlite_master WHERE tbl_name = %s AND type = %s",
            [table_name, "table"],
        )
        results = cursor.fetchone()[0].strip()
        # Keep only the text between the outermost parentheses, i.e. the
        # column and constraint definitions.
        results = results[results.index("(") + 1 : results.rindex(")")]

        # Walk through and look for references to other tables. SQLite doesn't
        # really have enforced references, but since it echoes out the SQL used
        # to create the table we can look for REFERENCES statements used there.
        for field_desc in results.split(","):
            field_desc = field_desc.strip()
            if field_desc.startswith("UNIQUE"):
                continue

            m = re.search(r'"(.*)".*references (.*) \(["|](.*)["|]\)', field_desc, re.I)
            if not m:
                continue

            # This will append
            # (column_name, referenced_table_name, referenced_column_name) to
            # key_columns.
            key_columns.append(tuple(s.strip('"') for s in m.groups()))

        return key_columns

    def get_primary_key_column(self, cursor, table_name):
        """
        Return the column name of the primary key for the given table.

        Return None for views and for tables where no column definition
        mentions PRIMARY KEY. Raise ValueError if the table doesn't exist.
        """
        # Don't use PRAGMA because that causes issues with some transactions
        cursor.execute(
            "SELECT sql, type FROM sqlite_master "
            "WHERE tbl_name = %s AND type IN ('table', 'view')",
            [table_name],
        )
        found = cursor.fetchone()
        if found is None:
            raise ValueError("Table %s does not exist" % table_name)
        schema, kind = found
        if kind == "view":
            # Views don't have a primary key.
            return None
        start = schema.index("(") + 1
        end = schema.rindex(")")
        # Group 1 captures a quoted/bracketed name, group 2 a bare one.
        pk_pattern = r'(?:(?:["`\[])(.*)(?:["`\]])|(\w+)).*PRIMARY KEY.*'
        for definition in schema[start:end].split(","):
            match = re.match(pk_pattern, definition.strip())
            if match:
                return match[1] or match[2]
        return None

    def _get_foreign_key_constraints(self, cursor, table_name):
        """
        Return a {"fk_<id>": constraint_dict} mapping built from
        PRAGMA foreign_key_list for the given table.
        """
        cursor.execute(
            "PRAGMA foreign_key_list(%s)" % self.connection.ops.quote_name(table_name)
        )
        # Only the first five values of each row are used; the remaining
        # on_update/on_delete/match values are of no interest.
        return {
            "fk_%d" % fk_id: {
                "columns": [from_column],
                "primary_key": False,
                "unique": False,
                "foreign_key": (to_table, to_column),
                "check": False,
                "index": False,
            }
            for fk_id, _, to_table, from_column, to_column in (
                row[:5] for row in cursor.fetchall()
            )
        }

    def _parse_column_or_constraint_definition(self, tokens, columns):
        """
        Consume one column or constraint definition from ``tokens`` and report
        the unique/check constraints found in it.

        ``tokens`` is an iterator of sqlparse tokens positioned just inside a
        table's definition list; it is advanced until a top-level comma or the
        closing parenthesis of the list is consumed. ``columns`` is the
        collection of the table's column names and is used to decide which
        names inside a CHECK expression count as columns.

        Return a 4-tuple ``(constraint_name, unique_constraint,
        check_constraint, token)``. ``constraint_name`` is only set for
        definitions starting with CONSTRAINT. Each constraint is a dict or None
        when no columns were collected for it. ``token`` is the last token
        consumed, or None if ``tokens`` was already exhausted.
        """
        token = None
        # None until the first token has been inspected, then truthy/falsy
        # depending on whether that token is the CONSTRAINT keyword.
        is_constraint_definition = None
        field_name = None
        constraint_name = None
        unique = False
        unique_columns = []
        check = False
        check_columns = []
        # Parenthesis nesting depth relative to the start of this definition.
        braces_deep = 0
        for token in tokens:
            if token.match(sqlparse.tokens.Punctuation, "("):
                braces_deep += 1
            elif token.match(sqlparse.tokens.Punctuation, ")"):
                braces_deep -= 1
                if braces_deep < 0:
                    # End of columns and constraints for table definition.
                    break
            elif braces_deep == 0 and token.match(sqlparse.tokens.Punctuation, ","):
                # End of current column or constraint definition.
                break
            # Detect column or constraint definition by first token.
            if is_constraint_definition is None:
                is_constraint_definition = token.match(
                    sqlparse.tokens.Keyword, "CONSTRAINT"
                )
                if is_constraint_definition:
                    continue
            if is_constraint_definition:
                # Detect constraint name by second token.
                if constraint_name is None:
                    if token.ttype in (sqlparse.tokens.Name, sqlparse.tokens.Keyword):
                        constraint_name = token.value
                    elif token.ttype == sqlparse.tokens.Literal.String.Symbol:
                        # Drop the surrounding quote characters.
                        constraint_name = token.value[1:-1]
                # Start constraint columns parsing after UNIQUE keyword.
                if token.match(sqlparse.tokens.Keyword, "UNIQUE"):
                    unique = True
                    # Remember the depth at which UNIQUE appeared; its column
                    # list is whatever is nested deeper than this.
                    unique_braces_deep = braces_deep
                elif unique:
                    if unique_braces_deep == braces_deep:
                        if unique_columns:
                            # Stop constraint parsing.
                            unique = False
                        # Note: this also skips the CHECK handling below for
                        # the current token.
                        continue
                    if token.ttype in (sqlparse.tokens.Name, sqlparse.tokens.Keyword):
                        unique_columns.append(token.value)
                    elif token.ttype == sqlparse.tokens.Literal.String.Symbol:
                        unique_columns.append(token.value[1:-1])
            else:
                # Detect field name by first token.
                if field_name is None:
                    if token.ttype in (sqlparse.tokens.Name, sqlparse.tokens.Keyword):
                        field_name = token.value
                    elif token.ttype == sqlparse.tokens.Literal.String.Symbol:
                        field_name = token.value[1:-1]
                # An inline UNIQUE on a column definition covers just that
                # column.
                if token.match(sqlparse.tokens.Keyword, "UNIQUE"):
                    unique_columns = [field_name]
            # Start constraint columns parsing after CHECK keyword.
            if token.match(sqlparse.tokens.Keyword, "CHECK"):
                check = True
                check_braces_deep = braces_deep
            elif check:
                if check_braces_deep == braces_deep:
                    if check_columns:
                        # Stop constraint parsing.
                        check = False
                    continue
                # Inside the CHECK expression only names that are known
                # columns of the table are recorded.
                if token.ttype in (sqlparse.tokens.Name, sqlparse.tokens.Keyword):
                    if token.value in columns:
                        check_columns.append(token.value)
                elif token.ttype == sqlparse.tokens.Literal.String.Symbol:
                    if token.value[1:-1] in columns:
                        check_columns.append(token.value[1:-1])
        unique_constraint = (
            {
                "unique": True,
                "columns": unique_columns,
                "primary_key": False,
                "foreign_key": None,
                "check": False,
                "index": False,
            }
            if unique_columns
            else None
        )
        check_constraint = (
            {
                "check": True,
                "columns": check_columns,
                "primary_key": False,
                "unique": False,
                "foreign_key": None,
                "index": False,
            }
            if check_columns
            else None
        )
        return constraint_name, unique_constraint, check_constraint, token

    def _parse_table_constraints(self, sql, columns):
        """
        Parse a CREATE TABLE statement and return a {name: constraint_dict}
        mapping of the unique and check constraints declared inside it.
        Constraints without a name get a generated
        "__unnamed_constraint_<n>__" key.
        """
        # Check constraint parsing is based on the SQLite syntax diagram.
        # https://www.sqlite.org/syntaxdiagrams.html#table-constraint
        statement = sqlparse.parse(sql)[0]
        tokens = (token for token in statement.flatten() if not token.is_whitespace)
        # Advance to the opening parenthesis of the definition list.
        for token in tokens:
            if token.match(sqlparse.tokens.Punctuation, "("):
                break
        constraints = {}
        unnamed_count = 0
        # Parse one definition at a time until the closing parenthesis.
        parsing = True
        while parsing:
            (
                name,
                unique,
                check,
                end_token,
            ) = self._parse_column_or_constraint_definition(tokens, columns)
            # Unique is registered before check, so a check constraint sharing
            # the same name replaces the unique one.
            for constraint in (unique, check):
                if not constraint:
                    continue
                if name:
                    constraints[name] = constraint
                else:
                    unnamed_count += 1
                    constraints["__unnamed_constraint_%s__" % unnamed_count] = (
                        constraint
                    )
            parsing = not end_token.match(sqlparse.tokens.Punctuation, ")")
        return constraints

    def get_constraints(self, cursor, table_name):
        """
        Retrieve any constraints or keys (unique, pk, fk, check, index) across
        one or more columns.

        Return a dict mapping each constraint name to a dict with the keys
        "columns", "primary_key", "unique", "foreign_key", "check" and
        "index". Index entries also get "type" and, when the column orders can
        be parsed, "orders".
        """
        constraints = {}
        # Find inline check constraints.
        # The table name is bound as a query parameter. quote_name() yields a
        # double-quoted identifier, which SQLite resolves as a column name
        # first (e.g. a table called "name" would compare the column to
        # itself), so it's only used for PRAGMA arguments below.
        try:
            table_schema = cursor.execute(
                "SELECT sql FROM sqlite_master WHERE type='table' and name=%s",
                [table_name],
            ).fetchone()[0]
        except TypeError:
            # table_name is a view.
            pass
        else:
            columns = {
                info.name for info in self.get_table_description(cursor, table_name)
            }
            constraints.update(self._parse_table_constraints(table_schema, columns))

        # Get the index info
        cursor.execute(
            "PRAGMA index_list(%s)" % self.connection.ops.quote_name(table_name)
        )
        for row in cursor.fetchall():
            # SQLite 3.8.9+ has 5 columns, however older versions only give 3
            # columns. Discard last 2 columns if there.
            number, index, unique = row[:3]
            cursor.execute(
                "SELECT sql FROM sqlite_master WHERE type='index' AND name=%s",
                [index],
            )
            # There's at most one row.
            (sql,) = cursor.fetchone() or (None,)
            # Inline constraints are already detected in
            # _parse_table_constraints(). The reasons to avoid fetching inline
            # constraints from `PRAGMA index_list` are:
            # - Inline constraints can have a different name and information
            #   than what `PRAGMA index_list` gives.
            # - Not all inline constraints may appear in `PRAGMA index_list`.
            if not sql:
                # An inline constraint
                continue
            # Get the index info for that index
            cursor.execute(
                "PRAGMA index_info(%s)" % self.connection.ops.quote_name(index)
            )
            for index_rank, column_rank, column in cursor.fetchall():
                if index not in constraints:
                    constraints[index] = {
                        "columns": [],
                        "primary_key": False,
                        "unique": bool(unique),
                        "foreign_key": None,
                        "check": False,
                        "index": True,
                    }
                constraints[index]["columns"].append(column)
            # Add type and column orders for indexes
            if constraints[index]["index"]:
                # SQLite doesn't support any index type other than b-tree
                constraints[index]["type"] = Index.suffix
                orders = self._get_index_columns_orders(sql)
                if orders is not None:
                    constraints[index]["orders"] = orders
        # Get the PK
        pk_column = self.get_primary_key_column(cursor, table_name)
        if pk_column:
            # SQLite doesn't actually give a name to the PK constraint,
            # so we invent one. This is fine, as the SQLite backend never
            # deletes PK constraints by name, as you can't delete constraints
            # in SQLite; we remake the table with a new PK instead.
            constraints["__primary__"] = {
                "columns": [pk_column],
                "primary_key": True,
                "unique": False,  # It's not actually a unique constraint.
                "foreign_key": None,
                "check": False,
                "index": False,
            }
        constraints.update(self._get_foreign_key_constraints(cursor, table_name))
        return constraints

    def _get_index_columns_orders(self, sql):
        """
        Return a list of "ASC"/"DESC" strings, one per ", "-separated entry of
        the first top-level parenthesized group in the CREATE INDEX statement,
        or None when the statement has no such group.
        """
        statement = sqlparse.parse(sql)[0]
        parenthesis = next(
            (
                token
                for token in statement
                if isinstance(token, sqlparse.sql.Parenthesis)
            ),
            None,
        )
        if parenthesis is None:
            return None
        orders = []
        for column_sql in str(parenthesis).strip("()").split(", "):
            # Anything not explicitly ending in DESC is reported as ASC.
            orders.append("DESC" if column_sql.endswith("DESC") else "ASC")
        return orders

    def _get_column_collations(self, cursor, table_name):
        """
        Return a {column_name: collation} dict for the given table, with None
        for columns that declare no collation. Return an empty dict when no
        table with that name exists.

        The table's CREATE statement is split on ", " and each fragment is
        scanned for a COLLATE keyword, so fragments that aren't column
        definitions (e.g. pieces of table constraints) may add extra keys.
        """
        row = cursor.execute(
            """
            SELECT sql
            FROM sqlite_master
            WHERE type = 'table' AND name = %s
        """,
            [table_name],
        ).fetchone()
        if not row:
            return {}

        sql = row[0]
        columns = str(sqlparse.parse(sql)[0][-1]).strip("()").split(", ")
        collations = {}
        for column in columns:
            # Split on whitespace first and strip the quotes afterwards, so
            # that unquoted column names keep their first character.
            tokens = column.split()
            if not tokens:
                # Empty fragment; nothing to record.
                continue
            column_name = tokens[0].strip('"')
            for index, token in enumerate(tokens):
                if token == "COLLATE":
                    collation = tokens[index + 1]
                    break
            else:
                collation = None
            # A later fragment without COLLATE (such as the trailing column of
            # an inline UNIQUE ("a", "b") constraint) must not erase a
            # collation already found for that column.
            if collation is not None or column_name not in collations:
                collations[column_name] = collation
        return collations