mirror of
				https://github.com/django/django.git
				synced 2025-10-31 09:41:08 +00:00 
			
		
		
		
	Fixed #22648 -- Transform.output_type should respect overridden custom_lookup and custom_transform.
Previously, class lookups from the output_type would be used, but any changes to custom_lookup or custom_transform would be ignored.
This commit is contained in:
		| @@ -22,18 +22,20 @@ class RegisterLookupMixin(object): | ||||
|         except AttributeError: | ||||
|             # This class didn't have any class_lookups | ||||
|             pass | ||||
|         if hasattr(self, 'output_type'): | ||||
|             return self.output_type.get_lookup(lookup_name) | ||||
|         return None | ||||
|  | ||||
|     def get_lookup(self, lookup_name): | ||||
|         found = self._get_lookup(lookup_name) | ||||
|         if found is None and hasattr(self, 'output_type'): | ||||
|             return self.output_type.get_lookup(lookup_name) | ||||
|         if found is not None and not issubclass(found, Lookup): | ||||
|             return None | ||||
|         return found | ||||
|  | ||||
|     def get_transform(self, lookup_name): | ||||
|         found = self._get_lookup(lookup_name) | ||||
|         if found is None and hasattr(self, 'output_type'): | ||||
|             return self.output_type.get_transform(lookup_name) | ||||
|         if found is not None and not issubclass(found, Transform): | ||||
|             return None | ||||
|         return found | ||||
|   | ||||
| @@ -89,6 +89,47 @@ class YearLte(models.lookups.LessThanOrEqual): | ||||
| YearTransform.register_lookup(YearLte) | ||||
|  | ||||
|  | ||||
| class SQLFunc(models.Lookup): | ||||
|     def __init__(self, name, *args, **kwargs): | ||||
|         super(SQLFunc, self).__init__(*args, **kwargs) | ||||
|         self.name = name | ||||
|  | ||||
|     def as_sql(self, qn, connection): | ||||
|         return '%s()', [self.name] | ||||
|  | ||||
|     @property | ||||
|     def output_type(self): | ||||
|         return CustomField() | ||||
|  | ||||
|  | ||||
| class SQLFuncFactory(object): | ||||
|  | ||||
|     def __init__(self, name): | ||||
|         self.name = name | ||||
|  | ||||
|     def __call__(self, *args, **kwargs): | ||||
|         return SQLFunc(self.name, *args, **kwargs) | ||||
|  | ||||
|  | ||||
| class CustomField(models.Field): | ||||
|  | ||||
|     def get_lookup(self, lookup_name): | ||||
|         if lookup_name.startswith('lookupfunc_'): | ||||
|             key, name = lookup_name.split('_', 1) | ||||
|             return SQLFuncFactory(name) | ||||
|         return super(CustomField, self).get_lookup(lookup_name) | ||||
|  | ||||
|     def get_transform(self, lookup_name): | ||||
|         if lookup_name.startswith('transformfunc_'): | ||||
|             key, name = lookup_name.split('_', 1) | ||||
|             return SQLFuncFactory(name) | ||||
|         return super(CustomField, self).get_transform(lookup_name) | ||||
|  | ||||
|  | ||||
| class CustomModel(models.Model): | ||||
|     field = CustomField() | ||||
|  | ||||
|  | ||||
| # We will register this class temporarily in the test method. | ||||
|  | ||||
|  | ||||
| @@ -341,3 +382,22 @@ class LookupTransformCallOrderTests(TestCase): | ||||
|  | ||||
|         finally: | ||||
|             models.DateField._unregister_lookup(TrackCallsYearTransform) | ||||
|  | ||||
|  | ||||
| class CustomisedMethodsTests(TestCase): | ||||
|  | ||||
|     def test_overridden_get_lookup(self): | ||||
|         q = CustomModel.objects.filter(field__lookupfunc_monkeys=3) | ||||
|         self.assertIn('monkeys()', str(q.query)) | ||||
|  | ||||
|     def test_overridden_get_transform(self): | ||||
|         q = CustomModel.objects.filter(field__transformfunc_banana=3) | ||||
|         self.assertIn('banana()', str(q.query)) | ||||
|  | ||||
|     def test_overridden_get_lookup_chain(self): | ||||
|         q = CustomModel.objects.filter(field__transformfunc_banana__lookupfunc_elephants=3) | ||||
|         self.assertIn('elephants()', str(q.query)) | ||||
|  | ||||
|     def test_overridden_get_transform_chain(self): | ||||
|         q = CustomModel.objects.filter(field__transformfunc_banana__transformfunc_pear=3) | ||||
|         self.assertIn('pear()', str(q.query)) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user