mirror of
				https://github.com/django/django.git
				synced 2025-10-24 22:26:08 +00:00 
			
		
		
		
	
		
			
				
	
	
		
			52 lines
		
	
	
		
			1.5 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			52 lines
		
	
	
		
			1.5 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| from django.db import connection, models
 | |
| from django.test import SimpleTestCase
 | |
| 
 | |
| from .utils import FuncTestMixin
 | |
| 
 | |
| 
 | |
def test_mutation(raises=True):
    """
    Build a decorator that turns a mutation function into a test method.

    The decorated function is called with a ``TestFunc`` instance from inside
    that instance's ``as_sql()`` and may modify it. The generated test calls
    the ``as_<vendor>()`` method (named after ``connection.vendor``) of a
    freshly created ``TestFunc``.

    If ``raises`` is true, the generated test asserts that the call fails with
    an ``AssertionError`` whose message is
    'TestFunc Func was mutated during compilation.'; otherwise the call is
    made without expecting that error.
    """
    def wrapper(mutation_func):
        def test(test_case_instance, *args, **kwargs):
            class TestFunc(models.Func):
                output_field = models.IntegerField()

                def __init__(self):
                    # A plain attribute plus two source expressions give the
                    # mutation functions several distinct things to modify.
                    self.attribute = 'initial'
                    super().__init__('initial', ['initial'])

                def as_sql(self, *args, **kwargs):
                    # Apply the mutation while the Func is being compiled.
                    mutation_func(self)
                    return '', ()

            def compile_for_vendor():
                # Look the method up by name on a new instance each time so
                # every run starts from an unmodified Func.
                func = TestFunc()
                vendor_method = getattr(func, 'as_' + connection.vendor)
                vendor_method(None, None)

            if not raises:
                compile_for_vendor()
                return

            msg = 'TestFunc Func was mutated during compilation.'
            with test_case_instance.assertRaisesMessage(AssertionError, msg):
                compile_for_vendor()

        return test
    return wrapper
 | |
| 
 | |
| 
 | |
class FuncTestMixinTests(FuncTestMixin, SimpleTestCase):
    """
    Exercise FuncTestMixin with functions that modify a Func instance.

    Each method below is replaced by the test that the ``test_mutation``
    decorator builds from it, so its single parameter is the Func instance
    handed over by that decorator rather than the test case.
    """

    # Rebind a plain instance attribute.
    @test_mutation()
    def test_mutated_attribute(func):
        func.attribute = 'mutated'

    # Empty the list of source expressions in place.
    @test_mutation()
    def test_mutated_expressions(func):
        func.source_expressions.clear()

    # Change the name attribute of the first source expression.
    @test_mutation()
    def test_mutated_expression(func):
        func.source_expressions[0].name = 'mutated'

    # Change an item inside the value held by the second source expression.
    @test_mutation()
    def test_mutated_expression_deep(func):
        func.source_expressions[1].value[0] = 'mutated'

    # Leave the Func untouched; the decorator is told not to expect an error.
    @test_mutation(raises=False)
    def test_not_mutated(func):
        pass
 |