diff --git a/django/contrib/messages/storage/base.py b/django/contrib/messages/storage/base.py index fd5d0c24aa..b2eeac77f4 100644 --- a/django/contrib/messages/storage/base.py +++ b/django/contrib/messages/storage/base.py @@ -25,8 +25,9 @@ class Message: self.extra_tags = str(self.extra_tags) if self.extra_tags is not None else None def __eq__(self, other): - return isinstance(other, Message) and self.level == other.level and \ - self.message == other.message + if not isinstance(other, Message): + return NotImplemented + return self.level == other.level and self.message == other.message def __str__(self): return str(self.message) diff --git a/django/contrib/postgres/constraints.py b/django/contrib/postgres/constraints.py index 2fcb076ecf..67e415ddcf 100644 --- a/django/contrib/postgres/constraints.py +++ b/django/contrib/postgres/constraints.py @@ -89,13 +89,14 @@ class ExclusionConstraint(BaseConstraint): return path, args, kwargs def __eq__(self, other): - return ( - isinstance(other, self.__class__) and - self.name == other.name and - self.index_type == other.index_type and - self.expressions == other.expressions and - self.condition == other.condition - ) + if isinstance(other, self.__class__): + return ( + self.name == other.name and + self.index_type == other.index_type and + self.expressions == other.expressions and + self.condition == other.condition + ) + return super().__eq__(other) def __repr__(self): return '<%s: index_type=%s, expressions=%s%s>' % ( diff --git a/django/core/validators.py b/django/core/validators.py index 2e00ca3ff3..38345a844f 100644 --- a/django/core/validators.py +++ b/django/core/validators.py @@ -324,8 +324,9 @@ class BaseValidator: raise ValidationError(self.message, code=self.code, params=params) def __eq__(self, other): + if not isinstance(other, self.__class__): + return NotImplemented return ( - isinstance(other, self.__class__) and self.limit_value == other.limit_value and self.message == other.message and self.code == other.code diff --git a/django/db/models/base.py b/django/db/models/base.py index 0b8425aa85..0a5e5ff673 100644 --- a/django/db/models/base.py +++ b/django/db/models/base.py @@ -522,7 +522,7 @@ class Model(metaclass=ModelBase): def __eq__(self, other): if not isinstance(other, Model): - return False + return NotImplemented if self._meta.concrete_model != other._meta.concrete_model: return False my_pk = self.pk diff --git a/django/db/models/constraints.py b/django/db/models/constraints.py index e7f81d3ee9..fe0d42a168 100644 --- a/django/db/models/constraints.py +++ b/django/db/models/constraints.py @@ -54,11 +54,9 @@ class CheckConstraint(BaseConstraint): return "<%s: check='%s' name=%r>" % (self.__class__.__name__, self.check, self.name) def __eq__(self, other): - return ( - isinstance(other, CheckConstraint) and - self.name == other.name and - self.check == other.check - ) + if isinstance(other, CheckConstraint): + return self.name == other.name and self.check == other.check + return super().__eq__(other) def deconstruct(self): path, args, kwargs = super().deconstruct() @@ -106,12 +104,13 @@ class UniqueConstraint(BaseConstraint): ) def __eq__(self, other): - return ( - isinstance(other, UniqueConstraint) and - self.name == other.name and - self.fields == other.fields and - self.condition == other.condition - ) + if isinstance(other, UniqueConstraint): + return ( + self.name == other.name and + self.fields == other.fields and + self.condition == other.condition + ) + return super().__eq__(other) def deconstruct(self): path, args, kwargs = super().deconstruct() diff --git a/django/db/models/expressions.py b/django/db/models/expressions.py index 2b59dd301a..5df765b626 100644 --- a/django/db/models/expressions.py +++ b/django/db/models/expressions.py @@ -401,7 +401,9 @@ class BaseExpression: return tuple(identity) def __eq__(self, other): - return isinstance(other, BaseExpression) and other.identity == self.identity + if not isinstance(other, BaseExpression): + return NotImplemented + return other.identity == self.identity def __hash__(self): return hash(self.identity) diff --git a/django/db/models/indexes.py b/django/db/models/indexes.py index b156366764..49f4989462 100644 --- a/django/db/models/indexes.py +++ b/django/db/models/indexes.py @@ -112,4 +112,6 @@ class Index: ) def __eq__(self, other): - return (self.__class__ == other.__class__) and (self.deconstruct() == other.deconstruct()) + if self.__class__ == other.__class__: + return self.deconstruct() == other.deconstruct() + return NotImplemented diff --git a/django/db/models/query.py b/django/db/models/query.py index 4417c17592..794e0faae7 100644 --- a/django/db/models/query.py +++ b/django/db/models/query.py @@ -1543,7 +1543,9 @@ class Prefetch: return None def __eq__(self, other): - return isinstance(other, Prefetch) and self.prefetch_to == other.prefetch_to + if not isinstance(other, Prefetch): + return NotImplemented + return self.prefetch_to == other.prefetch_to def __hash__(self): return hash((self.__class__, self.prefetch_to)) diff --git a/django/db/models/query_utils.py b/django/db/models/query_utils.py index 7a667814f4..189fb4fa44 100644 --- a/django/db/models/query_utils.py +++ b/django/db/models/query_utils.py @@ -309,8 +309,9 @@ class FilteredRelation: self.path = [] def __eq__(self, other): + if not isinstance(other, self.__class__): + return NotImplemented return ( - isinstance(other, self.__class__) and self.relation_name == other.relation_name and self.alias == other.alias and self.condition == other.condition diff --git a/django/template/context.py b/django/template/context.py index 8f349a3a96..f0a0cf2a00 100644 --- a/django/template/context.py +++ b/django/template/context.py @@ -124,12 +124,10 @@ class BaseContext: """ Compare two contexts by comparing theirs 'dicts' attributes. """ - return ( - isinstance(other, BaseContext) and - # because dictionaries can be put in different order - # we have to flatten them like in templates - self.flatten() == other.flatten() - ) + if not isinstance(other, BaseContext): + return NotImplemented + # flatten dictionaries because they can be put in a different order. + return self.flatten() == other.flatten() class Context(BaseContext): diff --git a/tests/basic/tests.py b/tests/basic/tests.py index 89f6048c96..5eada343e1 100644 --- a/tests/basic/tests.py +++ b/tests/basic/tests.py @@ -1,5 +1,6 @@ import threading from datetime import datetime, timedelta +from unittest import mock from django.core.exceptions import MultipleObjectsReturned, ObjectDoesNotExist from django.db import DEFAULT_DB_ALIAS, DatabaseError, connections, models @@ -354,6 +355,7 @@ class ModelTest(TestCase): self.assertNotEqual(object(), Article(id=1)) a = Article() self.assertEqual(a, a) + self.assertEqual(a, mock.ANY) self.assertNotEqual(Article(), a) def test_hash(self): diff --git a/tests/constraints/tests.py b/tests/constraints/tests.py index 3b28c99e7f..8e2eb11e2a 100644 --- a/tests/constraints/tests.py +++ b/tests/constraints/tests.py @@ -1,3 +1,5 @@ +from unittest import mock + from django.core.exceptions import ValidationError from django.db import IntegrityError, connection, models from django.db.models.constraints import BaseConstraint @@ -39,6 +41,7 @@ class CheckConstraintTests(TestCase): models.CheckConstraint(check=check1, name='price'), models.CheckConstraint(check=check1, name='price'), ) + self.assertEqual(models.CheckConstraint(check=check1, name='price'), mock.ANY) self.assertNotEqual( models.CheckConstraint(check=check1, name='price'), models.CheckConstraint(check=check1, name='price2'), @@ -102,6 +105,10 @@ class UniqueConstraintTests(TestCase): models.UniqueConstraint(fields=['foo', 'bar'], name='unique'), models.UniqueConstraint(fields=['foo', 'bar'], name='unique'), ) + self.assertEqual( + models.UniqueConstraint(fields=['foo', 'bar'], name='unique'), + mock.ANY, + ) self.assertNotEqual( models.UniqueConstraint(fields=['foo', 'bar'], name='unique'), models.UniqueConstraint(fields=['foo', 'bar'], name='unique2'), diff --git a/tests/expressions/tests.py b/tests/expressions/tests.py index f50c634014..094b738792 100644 --- a/tests/expressions/tests.py +++ b/tests/expressions/tests.py @@ -3,6 +3,7 @@ import pickle import unittest import uuid from copy import deepcopy +from unittest import mock from django.core.exceptions import FieldError from django.db import DatabaseError, connection, models @@ -965,6 +966,7 @@ class SimpleExpressionTests(SimpleTestCase): Expression(models.IntegerField()), Expression(output_field=models.IntegerField()) ) + self.assertEqual(Expression(models.IntegerField()), mock.ANY) self.assertNotEqual( Expression(models.IntegerField()), Expression(models.CharField()) diff --git a/tests/filtered_relation/tests.py b/tests/filtered_relation/tests.py index 52fe64dfa5..48154413a5 100644 --- a/tests/filtered_relation/tests.py +++ b/tests/filtered_relation/tests.py @@ -1,3 +1,5 @@ +from unittest import mock + from django.db import connection, transaction from django.db.models import Case, Count, F, FilteredRelation, Q, When from django.test import TestCase @@ -323,6 +325,9 @@ class FilteredRelationTests(TestCase): [self.book1] ) + def test_eq(self): + self.assertEqual(FilteredRelation('book', condition=Q(book__title='b')), mock.ANY) + class FilteredRelationAggregationTests(TestCase): diff --git a/tests/messages_tests/tests.py b/tests/messages_tests/tests.py index 1464783b33..eea07c9c41 100644 --- a/tests/messages_tests/tests.py +++ b/tests/messages_tests/tests.py @@ -1,3 +1,5 @@ +from unittest import mock + from django.contrib.messages import constants from django.contrib.messages.storage.base import Message from django.test import SimpleTestCase @@ -9,6 +11,7 @@ class MessageTests(SimpleTestCase): msg_2 = Message(constants.INFO, 'Test message 2') msg_3 = Message(constants.WARNING, 'Test message 1') self.assertEqual(msg_1, msg_1) + self.assertEqual(msg_1, mock.ANY) self.assertNotEqual(msg_1, msg_2) self.assertNotEqual(msg_1, msg_3) self.assertNotEqual(msg_2, msg_3) diff --git a/tests/model_indexes/tests.py b/tests/model_indexes/tests.py index ade27e1a4b..6a31109031 100644 --- a/tests/model_indexes/tests.py +++ b/tests/model_indexes/tests.py @@ -1,3 +1,5 @@ +from unittest import mock + from django.conf import settings from django.db import connection, models from django.db.models.query_utils import Q @@ -28,6 +30,7 @@ class SimpleIndexesTests(SimpleTestCase): same_index.model = Book another_index.model = Book self.assertEqual(index, same_index) + self.assertEqual(index, mock.ANY) self.assertNotEqual(index, another_index) def test_index_fields_type(self): diff --git a/tests/postgres_tests/test_constraints.py b/tests/postgres_tests/test_constraints.py index d8665f59f6..b22821294a 100644 --- a/tests/postgres_tests/test_constraints.py +++ b/tests/postgres_tests/test_constraints.py @@ -1,4 +1,5 @@ import datetime +from unittest import mock from django.db import connection, transaction from django.db.models import F, Func, Q @@ -175,6 +176,7 @@ class ExclusionConstraintTests(PostgreSQLTestCase): condition=Q(cancelled=False), ) self.assertEqual(constraint_1, constraint_1) + self.assertEqual(constraint_1, mock.ANY) self.assertNotEqual(constraint_1, constraint_2) self.assertNotEqual(constraint_1, constraint_3) self.assertNotEqual(constraint_2, constraint_3) diff --git a/tests/prefetch_related/tests.py b/tests/prefetch_related/tests.py index 9ae939dcdf..930ba9fbc8 100644 --- a/tests/prefetch_related/tests.py +++ b/tests/prefetch_related/tests.py @@ -1,3 +1,5 @@ +from unittest import mock + from django.contrib.contenttypes.models import ContentType from django.core.exceptions import ObjectDoesNotExist from django.db import connection @@ -243,6 +245,7 @@ class PrefetchRelatedTests(TestDataMixin, TestCase): prefetch_1 = Prefetch('authors', queryset=Author.objects.all()) prefetch_2 = Prefetch('books', queryset=Book.objects.all()) self.assertEqual(prefetch_1, prefetch_1) + self.assertEqual(prefetch_1, mock.ANY) self.assertNotEqual(prefetch_1, prefetch_2) def test_forward_m2m_to_attr_conflict(self): diff --git a/tests/template_tests/test_context.py b/tests/template_tests/test_context.py index 8c6fc98b42..1150a14639 100644 --- a/tests/template_tests/test_context.py +++ b/tests/template_tests/test_context.py @@ -1,3 +1,5 @@ +from unittest import mock + from django.http import HttpRequest from django.template import ( Context, Engine, RequestContext, Template, Variable, VariableDoesNotExist, @@ -18,6 +20,7 @@ class ContextTests(SimpleTestCase): self.assertEqual(c.pop(), {"a": 2}) self.assertEqual(c["a"], 1) self.assertEqual(c.get("foo", 42), 42) + self.assertEqual(c, mock.ANY) def test_push_context_manager(self): c = Context({"a": 1}) diff --git a/tests/validators/tests.py b/tests/validators/tests.py index 36d0b2a520..295c6c899f 100644 --- a/tests/validators/tests.py +++ b/tests/validators/tests.py @@ -3,7 +3,7 @@ import re import types from datetime import datetime, timedelta from decimal import Decimal -from unittest import TestCase +from unittest import TestCase, mock from django.core.exceptions import ValidationError from django.core.files.base import ContentFile @@ -424,6 +424,7 @@ class TestValidatorEquality(TestCase): MaxValueValidator(44), MaxValueValidator(44), ) + self.assertEqual(MaxValueValidator(44), mock.ANY) self.assertNotEqual( MaxValueValidator(44), MinValueValidator(44),