From dc27f3ee0c3eb9bb17d6cb764788eeaf73a371d7 Mon Sep 17 00:00:00 2001 From: Simon Charette Date: Thu, 26 Mar 2015 16:54:43 -0400 Subject: [PATCH] Fixed #19259 -- Added group by selected primary keys support. --- django/db/backends/base/features.py | 1 + .../db/backends/postgresql_psycopg2/features.py | 1 + django/db/models/sql/compiler.py | 16 ++++++++++++---- django/test/__init__.py | 6 +++--- django/test/testcases.py | 10 ++++++++++ tests/aggregation_regress/tests.py | 12 +++++++----- 6 files changed, 34 insertions(+), 12 deletions(-) diff --git a/django/db/backends/base/features.py b/django/db/backends/base/features.py index d48a773aa7..d9463ddf19 100644 --- a/django/db/backends/base/features.py +++ b/django/db/backends/base/features.py @@ -6,6 +6,7 @@ from django.utils.functional import cached_property class BaseDatabaseFeatures(object): gis_enabled = False allows_group_by_pk = False + allows_group_by_selected_pks = False # True if django.db.backends.utils.typecast_timestamp is used on values # returned from dates() calls. needs_datetime_string_cast = True diff --git a/django/db/backends/postgresql_psycopg2/features.py b/django/db/backends/postgresql_psycopg2/features.py index 6bb6de1a96..789e0e0ccf 100644 --- a/django/db/backends/postgresql_psycopg2/features.py +++ b/django/db/backends/postgresql_psycopg2/features.py @@ -3,6 +3,7 @@ from django.db.utils import InterfaceError class DatabaseFeatures(BaseDatabaseFeatures): + allows_group_by_selected_pks = True needs_datetime_string_cast = False can_return_id_from_insert = True has_real_datatype = True diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py index 002bc05824..7ea9ef9067 100644 --- a/django/db/models/sql/compiler.py +++ b/django/db/models/sql/compiler.py @@ -136,10 +136,7 @@ class SQLCompiler(object): # If the DB can group by primary key, then group by the primary key of # query's main model. Note that for PostgreSQL the GROUP BY clause must # include the primary key of every table, but for MySQL it is enough to - # have the main table's primary key. Currently only the MySQL form is - # implemented. - # MySQLism: however, columns in HAVING clause must be added to the - # GROUP BY. + # have the main table's primary key. if self.connection.features.allows_group_by_pk: # The logic here is: if the main model's primary key is in the # query, then set new_expressions to that field. If that happens, @@ -150,7 +147,18 @@ class SQLCompiler(object): getattr(expr.output_field, 'model') == self.query.model): pk = expr if pk: + # MySQLism: Columns in HAVING clause must be added to the GROUP BY. expressions = [pk] + [expr for expr in expressions if expr in having] + elif self.connection.features.allows_group_by_selected_pks: + # Filter out all expressions associated with a table's primary key + # present in the grouped columns. This is done by identifying all + # tables that have their primary key included in the grouped + # columns and removing non-primary key columns referring to them. + pks = {expr for expr in expressions if hasattr(expr, 'target') and expr.target.primary_key} + aliases = {expr.alias for expr in pks} + expressions = [ + expr for expr in expressions if expr in pks or getattr(expr, 'alias', None) not in aliases + ] return expressions def get_select(self): diff --git a/django/test/__init__.py b/django/test/__init__.py index f5213301d3..884b1c5ee3 100644 --- a/django/test/__init__.py +++ b/django/test/__init__.py @@ -6,7 +6,7 @@ from django.test.client import Client, RequestFactory from django.test.testcases import ( TestCase, TransactionTestCase, SimpleTestCase, LiveServerTestCase, skipIfDBFeature, - skipUnlessDBFeature + skipUnlessAnyDBFeature, skipUnlessDBFeature ) from django.test.utils import (ignore_warnings, modify_settings, override_settings, override_system_checks) @@ -14,8 +14,8 @@ from django.test.utils import (ignore_warnings, modify_settings, __all__ = [ 'Client', 'RequestFactory', 'TestCase', 'TransactionTestCase', 'SimpleTestCase', 'LiveServerTestCase', 'skipIfDBFeature', - 'skipUnlessDBFeature', 'ignore_warnings', 'modify_settings', - 'override_settings', 'override_system_checks' + 'skipUnlessAnyDBFeature', 'skipUnlessDBFeature', 'ignore_warnings', + 'modify_settings', 'override_settings', 'override_system_checks' ] # To simplify Django's test suite; not meant as a public API diff --git a/django/test/testcases.py b/django/test/testcases.py index 461ab79144..cff2f4411e 100644 --- a/django/test/testcases.py +++ b/django/test/testcases.py @@ -1051,6 +1051,16 @@ def skipUnlessDBFeature(*features): ) +def skipUnlessAnyDBFeature(*features): + """ + Skip a test unless a database has any of the named features. + """ + return _deferredSkip( + lambda: not any(getattr(connection.features, feature, False) for feature in features), + "Database doesn't support any of the feature(s): %s" % ", ".join(features) + ) + + class QuietWSGIRequestHandler(WSGIRequestHandler): """ Just a regular WSGIRequestHandler except it doesn't log to the standard diff --git a/tests/aggregation_regress/tests.py b/tests/aggregation_regress/tests.py index bfd202fbc9..596f69c2dc 100644 --- a/tests/aggregation_regress/tests.py +++ b/tests/aggregation_regress/tests.py @@ -7,10 +7,11 @@ from operator import attrgetter from django.contrib.contenttypes.models import ContentType from django.core.exceptions import FieldError +from django.db import connection from django.db.models import ( F, Q, Avg, Count, Max, StdDev, Sum, Value, Variance, ) -from django.test import TestCase, skipUnlessDBFeature +from django.test import TestCase, skipUnlessAnyDBFeature, skipUnlessDBFeature from django.test.utils import Approximate from django.utils import six @@ -1011,7 +1012,7 @@ class AggregationTests(TestCase): # Check that the query executes without problems. self.assertEqual(len(qs.exclude(publisher=-1)), 6) - @skipUnlessDBFeature("allows_group_by_pk") + @skipUnlessAnyDBFeature('allows_group_by_pk', 'allows_group_by_selected_pks') def test_aggregate_duplicate_columns(self): # Regression test for #17144 @@ -1041,7 +1042,7 @@ class AggregationTests(TestCase): ] ) - @skipUnlessDBFeature("allows_group_by_pk") + @skipUnlessAnyDBFeature('allows_group_by_pk', 'allows_group_by_selected_pks') def test_aggregate_duplicate_columns_only(self): # Works with only() too. results = Author.objects.only('id', 'name').annotate(num_contacts=Count('book_contact_set')) @@ -1067,13 +1068,14 @@ class AggregationTests(TestCase): ] ) - @skipUnlessDBFeature("allows_group_by_pk") + @skipUnlessAnyDBFeature('allows_group_by_pk', 'allows_group_by_selected_pks') def test_aggregate_duplicate_columns_select_related(self): # And select_related() results = Book.objects.select_related('contact').annotate( num_authors=Count('authors')) _, _, grouping = results.query.get_compiler(using='default').pre_sql_setup() - self.assertEqual(len(grouping), 1) + # In the case of `group_by_selected_pks` we also group by contact.id because of the select_related. + self.assertEqual(len(grouping), 1 if connection.features.allows_group_by_pk else 2) self.assertIn('id', grouping[0][0]) self.assertNotIn('name', grouping[0][0]) self.assertNotIn('contact', grouping[0][0])