From abd0ad7681422d7c40a5ed12cc3c9ffca6b88422 Mon Sep 17 00:00:00 2001 From: oliver Date: Thu, 2 Aug 2018 14:20:46 +0900 Subject: [PATCH] Fixed #29626, #29584 -- Added optimized versions of get_many() and delete_many() for the db cache backend. --- django/core/cache/backends/db.py | 86 ++++++++++++++++++++------------ tests/cache/tests.py | 14 ++++++ 2 files changed, 67 insertions(+), 33 deletions(-) diff --git a/django/core/cache/backends/db.py b/django/core/cache/backends/db.py index 76aff9c582..21b5aa88ad 100644 --- a/django/core/cache/backends/db.py +++ b/django/core/cache/backends/db.py @@ -49,8 +49,17 @@ class DatabaseCache(BaseDatabaseCache): pickle_protocol = pickle.HIGHEST_PROTOCOL def get(self, key, default=None, version=None): - key = self.make_key(key, version=version) - self.validate_key(key) + return self.get_many([key], version).get(key, default) + + def get_many(self, keys, version=None): + if not keys: + return {} + + key_map = {} + for key in keys: + self.validate_key(key) + key_map[self.make_key(key, version)] = key + db = router.db_for_read(self.cache_model_class) connection = connections[db] quote_name = connection.ops.quote_name @@ -58,43 +67,36 @@ class DatabaseCache(BaseDatabaseCache): with connection.cursor() as cursor: cursor.execute( - 'SELECT %s, %s, %s FROM %s WHERE %s = %%s' % ( + 'SELECT %s, %s, %s FROM %s WHERE %s IN (%s)' % ( quote_name('cache_key'), quote_name('value'), quote_name('expires'), table, quote_name('cache_key'), + ', '.join(['%s'] * len(key_map)), ), - [key] + list(key_map), ) - row = cursor.fetchone() - if row is None: - return default + rows = cursor.fetchall() - expires = row[2] + result = {} + expired_keys = [] expression = models.Expression(output_field=models.DateTimeField()) - for converter in (connection.ops.get_db_converters(expression) + - expression.get_db_converters(connection)): - if func_supports_parameter(converter, 'context'): # RemovedInDjango30Warning - expires = converter(expires, expression, connection, {}) + converters = (connection.ops.get_db_converters(expression) + expression.get_db_converters(connection)) + for key, value, expires in rows: + for converter in converters: + if func_supports_parameter(converter, 'context'): # RemovedInDjango30Warning + expires = converter(expires, expression, connection, {}) + else: + expires = converter(expires, expression, connection) + if expires < timezone.now(): + expired_keys.append(key) else: - expires = converter(expires, expression, connection) - - if expires < timezone.now(): - db = router.db_for_write(self.cache_model_class) - connection = connections[db] - with connection.cursor() as cursor: - cursor.execute( - 'DELETE FROM %s WHERE %s = %%s' % ( - table, - quote_name('cache_key'), - ), - [key] - ) - return default - - value = connection.ops.process_clob(row[1]) - return pickle.loads(base64.b64decode(value.encode())) + value = connection.ops.process_clob(value) + value = pickle.loads(base64.b64decode(value.encode())) + result[key_map.get(key)] = value + self._base_delete_many(expired_keys) + return result def set(self, key, value, timeout=DEFAULT_TIMEOUT, version=None): key = self.make_key(key, version=version) @@ -202,15 +204,33 @@ class DatabaseCache(BaseDatabaseCache): return True def delete(self, key, version=None): - key = self.make_key(key, version=version) - self.validate_key(key) + self.delete_many([key], version) + + def delete_many(self, keys, version=None): + key_list = [] + for key in keys: + self.validate_key(key) + key_list.append(self.make_key(key, version)) + self._base_delete_many(key_list) + + def _base_delete_many(self, keys): + if not keys: + return db = router.db_for_write(self.cache_model_class) connection = connections[db] - table = connection.ops.quote_name(self._table) + quote_name = connection.ops.quote_name + table = quote_name(self._table) with connection.cursor() as cursor: - cursor.execute("DELETE FROM %s WHERE cache_key = %%s" % table, [key]) + cursor.execute( + 'DELETE FROM %s WHERE %s IN (%s)' % ( + table, + quote_name('cache_key'), + ', '.join(['%s'] * len(keys)), + ), + keys, + ) def has_key(self, key, version=None): key = self.make_key(key, version=version) diff --git a/tests/cache/tests.py b/tests/cache/tests.py index a101639f49..6578eb288f 100644 --- a/tests/cache/tests.py +++ b/tests/cache/tests.py @@ -1005,6 +1005,20 @@ class DBCacheTests(BaseCacheTests, TransactionTestCase): table_name = connection.ops.quote_name('test cache table') cursor.execute('DROP TABLE %s' % table_name) + def test_get_many_num_queries(self): + cache.set_many({'a': 1, 'b': 2}) + cache.set('expired', 'expired', 0.01) + with self.assertNumQueries(1): + self.assertEqual(cache.get_many(['a', 'b']), {'a': 1, 'b': 2}) + time.sleep(0.02) + with self.assertNumQueries(2): + self.assertEqual(cache.get_many(['a', 'b', 'expired']), {'a': 1, 'b': 2}) + + def test_delete_many_num_queries(self): + cache.set_many({'a': 1, 'b': 2, 'c': 3}) + with self.assertNumQueries(1): + cache.delete_many(['a', 'b', 'c']) + def test_zero_cull(self): self._perform_cull_test(caches['zero_cull'], 50, 18)