1
0
mirror of https://github.com/django/django.git synced 2025-10-23 21:59:11 +00:00

Refs #30054, #20483 -- Cached SQLite references graph retrieval on sql_flush().

This commit is contained in:
Simon Charette
2018-12-25 12:23:08 -05:00
committed by Tim Graham
parent 7a6dbbb655
commit 2b2ae4eeb7

View File

@@ -1,6 +1,8 @@
import datetime
import decimal
import uuid
from functools import lru_cache
from itertools import chain
from django.conf import settings
from django.core.exceptions import FieldError
@@ -11,6 +13,7 @@ from django.db.models.expressions import Col
from django.utils import timezone
from django.utils.dateparse import parse_date, parse_datetime, parse_time
from django.utils.duration import duration_microseconds
from django.utils.functional import cached_property
class DatabaseOperations(BaseDatabaseOperations):
@@ -158,28 +161,36 @@ class DatabaseOperations(BaseDatabaseOperations):
def no_limit_value(self):
return -1
def __references_graph(self, table_name):
query = """
WITH tables AS (
SELECT %s name
UNION
SELECT sqlite_master.name
FROM sqlite_master
JOIN tables ON (sql REGEXP %s || tables.name || %s)
) SELECT name FROM tables;
"""
params = (
table_name,
r'(?i)\s+references\s+("|\')?',
r'("|\')?\s*\(',
)
with self.connection.cursor() as cursor:
results = cursor.execute(query, params)
return [row[0] for row in results.fetchall()]
@cached_property
def _references_graph(self):
# 512 is large enough to fit the ~330 tables (as of this writing) in
# Django's test suite.
return lru_cache(maxsize=512)(self.__references_graph)
def sql_flush(self, style, tables, sequences, allow_cascade=False):
if tables and allow_cascade:
# Simulate TRUNCATE CASCADE by recursively collecting the tables
# referencing the tables to be flushed.
query = """
WITH tables AS (
%s
UNION
SELECT sqlite_master.name
FROM sqlite_master
JOIN tables ON (
sql REGEXP %%s || tables.name || %%s
)
) SELECT name FROM tables;
""" % ' UNION '.join("SELECT '%s' name" % table for table in tables)
params = (
r'(?i)\s+references\s+("|\')?',
r'("|\')?\s*\(',
)
with self.connection.cursor() as cursor:
results = cursor.execute(query, params)
tables = [row[0] for row in results.fetchall()]
tables = set(chain.from_iterable(self._references_graph(table) for table in tables))
sql = ['%s %s %s;' % (
style.SQL_KEYWORD('DELETE'),
style.SQL_KEYWORD('FROM'),