From 0837eacc4e1fa7916e48135e8ba43f54a7a64997 Mon Sep 17 00:00:00 2001
From: Michael Manfre <mmanfre@gmail.com>
Date: Wed, 8 Jan 2014 23:31:34 -0500
Subject: [PATCH] Made SQLCompiler.execute_sql(result_type) more explicit.

Updated SQLUpdateCompiler.execute_sql to match the behavior described in
the docstring; the 'first non-empty query' will now include all queries,
not just the main and first related update.

Added CURSOR and NO_RESULTS result_type constants to make the usages more
self documenting and allow execute_sql to explicitly close the cursor when
it is no longer needed.
---
 django/db/models/query.py          |  5 +-
 django/db/models/sql/compiler.py   | 78 ++++++++++++++++++++++--------
 django/db/models/sql/constants.py  |  2 +
 django/db/models/sql/subqueries.py |  8 +--
 tests/backends/tests.py            |  3 +-
 5 files changed, 69 insertions(+), 27 deletions(-)

diff --git a/django/db/models/query.py b/django/db/models/query.py
index 48d295ccca..353dd95794 100644
--- a/django/db/models/query.py
+++ b/django/db/models/query.py
@@ -14,6 +14,7 @@ from django.db.models.fields import AutoField, Empty
 from django.db.models.query_utils import (Q, select_related_descend,
     deferred_class_factory, InvalidQuery)
 from django.db.models.deletion import Collector
+from django.db.models.sql.constants import CURSOR
 from django.db.models import sql
 from django.utils.functional import partition
 from django.utils import six
@@ -574,7 +575,7 @@ class QuerySet(object):
         query = self.query.clone(sql.UpdateQuery)
         query.add_update_values(kwargs)
         with transaction.commit_on_success_unless_managed(using=self.db):
-            rows = query.get_compiler(self.db).execute_sql(None)
+            rows = query.get_compiler(self.db).execute_sql(CURSOR)
         self._result_cache = None
         return rows
     update.alters_data = True
@@ -591,7 +592,7 @@ class QuerySet(object):
         query = self.query.clone(sql.UpdateQuery)
         query.add_update_fields(values)
         self._result_cache = None
-        return query.get_compiler(self.db).execute_sql(None)
+        return query.get_compiler(self.db).execute_sql(CURSOR)
     _update.alters_data = True
     _update.queryset_only = False
 
diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py
index 123427cf8b..536a66d139 100644
--- a/django/db/models/sql/compiler.py
+++ b/django/db/models/sql/compiler.py
@@ -5,8 +5,8 @@ from django.core.exceptions import FieldError
 from django.db.backends.utils import truncate_name
 from django.db.models.constants import LOOKUP_SEP
 from django.db.models.query_utils import select_related_descend, QueryWrapper
-from django.db.models.sql.constants import (SINGLE, MULTI, ORDER_DIR,
-        GET_ITERATOR_CHUNK_SIZE, SelectInfo)
+from django.db.models.sql.constants import (CURSOR, SINGLE, MULTI, NO_RESULTS,
+        ORDER_DIR, GET_ITERATOR_CHUNK_SIZE, SelectInfo)
 from django.db.models.sql.datastructures import EmptyResultSet
 from django.db.models.sql.expressions import SQLEvaluator
 from django.db.models.sql.query import get_order_dir, Query
@@ -762,6 +762,8 @@ class SQLCompiler(object):
         is needed, as the filters describe an empty set. In that case, None is
         returned, to avoid any unnecessary database interaction.
         """
+        if not result_type:
+            result_type = NO_RESULTS
         try:
             sql, params = self.as_sql()
             if not sql:
@@ -773,27 +775,44 @@ class SQLCompiler(object):
                 return
 
         cursor = self.connection.cursor()
-        cursor.execute(sql, params)
+        try:
+            cursor.execute(sql, params)
+        except:
+            cursor.close()
+            raise
 
-        if not result_type:
+        if result_type == CURSOR:
+            # Caller didn't specify a result_type, so just give them back the
+            # cursor to process (and close).
             return cursor
         if result_type == SINGLE:
-            if self.ordering_aliases:
-                return cursor.fetchone()[:-len(self.ordering_aliases)]
-            return cursor.fetchone()
+            try:
+                if self.ordering_aliases:
+                    return cursor.fetchone()[:-len(self.ordering_aliases)]
+                return cursor.fetchone()
+            finally:
+                # done with the cursor
+                cursor.close()
+        if result_type == NO_RESULTS:
+            cursor.close()
+            return
 
         # The MULTI case.
         if self.ordering_aliases:
             result = order_modified_iter(cursor, len(self.ordering_aliases),
                     self.connection.features.empty_fetchmany_value)
         else:
-            result = iter((lambda: cursor.fetchmany(GET_ITERATOR_CHUNK_SIZE)),
-                    self.connection.features.empty_fetchmany_value)
+            result = cursor_iter(cursor,
+                self.connection.features.empty_fetchmany_value)
         if not self.connection.features.can_use_chunked_reads:
-            # If we are using non-chunked reads, we return the same data
-            # structure as normally, but ensure it is all read into memory
-            # before going any further.
-            return list(result)
+            try:
+                # If we are using non-chunked reads, we return the same data
+                # structure as normally, but ensure it is all read into memory
+                # before going any further.
+                return list(result)
+            finally:
+                # done with the cursor
+                cursor.close()
         return result
 
     def as_subquery_condition(self, alias, columns, qn):
@@ -970,12 +989,15 @@ class SQLUpdateCompiler(SQLCompiler):
         related queries are not available.
         """
         cursor = super(SQLUpdateCompiler, self).execute_sql(result_type)
-        rows = cursor.rowcount if cursor else 0
-        is_empty = cursor is None
-        del cursor
+        try:
+            rows = cursor.rowcount if cursor else 0
+            is_empty = cursor is None
+        finally:
+            if cursor:
+                cursor.close()
         for query in self.query.get_related_updates():
             aux_rows = query.get_compiler(self.using).execute_sql(result_type)
-            if is_empty:
+            if is_empty and aux_rows:
                 rows = aux_rows
                 is_empty = False
         return rows
@@ -1111,6 +1133,19 @@ class SQLDateTimeCompiler(SQLCompiler):
                 yield datetime
 
 
+def cursor_iter(cursor, sentinel):
+    """
+    Yields blocks of rows from a cursor and ensures the cursor is closed when
+    done.
+    """
+    try:
+        for rows in iter((lambda: cursor.fetchmany(GET_ITERATOR_CHUNK_SIZE)),
+                sentinel):
+            yield rows
+    finally:
+        cursor.close()
+
+
 def order_modified_iter(cursor, trim, sentinel):
     """
     Yields blocks of rows from a cursor. We use this iterator in the special
@@ -1118,6 +1153,9 @@ def order_modified_iter(cursor, trim, sentinel):
     requirements. We must trim those extra columns before anything else can use
     the results, since they're only needed to make the SQL valid.
     """
-    for rows in iter((lambda: cursor.fetchmany(GET_ITERATOR_CHUNK_SIZE)),
-            sentinel):
-        yield [r[:-trim] for r in rows]
+    try:
+        for rows in iter((lambda: cursor.fetchmany(GET_ITERATOR_CHUNK_SIZE)),
+                sentinel):
+            yield [r[:-trim] for r in rows]
+    finally:
+        cursor.close()
diff --git a/django/db/models/sql/constants.py b/django/db/models/sql/constants.py
index 904f7b2c8b..36aab23bae 100644
--- a/django/db/models/sql/constants.py
+++ b/django/db/models/sql/constants.py
@@ -33,6 +33,8 @@ SelectInfo = namedtuple('SelectInfo', 'col field')
 # How many results to expect from a cursor.execute call
 MULTI = 'multi'
 SINGLE = 'single'
+CURSOR = 'cursor'
+NO_RESULTS = 'no results'
 
 ORDER_PATTERN = re.compile(r'\?|[-+]?[.\w]+$')
 ORDER_DIR = {
diff --git a/django/db/models/sql/subqueries.py b/django/db/models/sql/subqueries.py
index 86b1efd3f8..cfda1f552c 100644
--- a/django/db/models/sql/subqueries.py
+++ b/django/db/models/sql/subqueries.py
@@ -8,7 +8,7 @@ from django.db import connections
 from django.db.models.query_utils import Q
 from django.db.models.constants import LOOKUP_SEP
 from django.db.models.fields import DateField, DateTimeField, FieldDoesNotExist
-from django.db.models.sql.constants import GET_ITERATOR_CHUNK_SIZE, SelectInfo
+from django.db.models.sql.constants import GET_ITERATOR_CHUNK_SIZE, NO_RESULTS, SelectInfo
 from django.db.models.sql.datastructures import Date, DateTime
 from django.db.models.sql.query import Query
 from django.utils import six
@@ -30,7 +30,7 @@ class DeleteQuery(Query):
     def do_query(self, table, where, using):
         self.tables = [table]
         self.where = where
-        self.get_compiler(using).execute_sql(None)
+        self.get_compiler(using).execute_sql(NO_RESULTS)
 
     def delete_batch(self, pk_list, using, field=None):
         """
@@ -82,7 +82,7 @@ class DeleteQuery(Query):
                 values = innerq
             self.where = self.where_class()
             self.add_q(Q(pk__in=values))
-        self.get_compiler(using).execute_sql(None)
+        self.get_compiler(using).execute_sql(NO_RESULTS)
 
 
 class UpdateQuery(Query):
@@ -116,7 +116,7 @@ class UpdateQuery(Query):
         for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE):
             self.where = self.where_class()
             self.add_q(Q(pk__in=pk_list[offset: offset + GET_ITERATOR_CHUNK_SIZE]))
-            self.get_compiler(using).execute_sql(None)
+            self.get_compiler(using).execute_sql(NO_RESULTS)
 
     def add_update_values(self, values):
         """
diff --git a/tests/backends/tests.py b/tests/backends/tests.py
index 0ff3ad0bba..4a3fc31b7a 100644
--- a/tests/backends/tests.py
+++ b/tests/backends/tests.py
@@ -20,6 +20,7 @@ from django.db.backends.utils import format_number, CursorWrapper
 from django.db.models import Sum, Avg, Variance, StdDev
 from django.db.models.fields import (AutoField, DateField, DateTimeField,
     DecimalField, IntegerField, TimeField)
+from django.db.models.sql.constants import CURSOR
 from django.db.utils import ConnectionHandler
 from django.test import (TestCase, TransactionTestCase, override_settings,
     skipUnlessDBFeature, skipIfDBFeature)
@@ -209,7 +210,7 @@ class LastExecutedQueryTest(TestCase):
         """
         persons = models.Reporter.objects.filter(raw_data=b'\x00\x46  \xFE').extra(select={'föö': 1})
         sql, params = persons.query.sql_with_params()
-        cursor = persons.query.get_compiler('default').execute_sql(None)
+        cursor = persons.query.get_compiler('default').execute_sql(CURSOR)
         last_sql = cursor.db.ops.last_executed_query(cursor, sql, params)
         self.assertIsInstance(last_sql, six.text_type)