mirror of
				https://github.com/django/django.git
				synced 2025-10-25 06:36:07 +00:00 
			
		
		
		
	Pass values through get_db_prep_save() in a QuerySet.update() call.
This removes a long-standing FIXME in the update() handling and allows for greater flexibility in the values passed in. In particular, it brings updates into line with saves for django.contrib.gis fields, so fixed #10411. Thanks to Justin Bronn and Russell Keith-Magee for help with this patch. git-svn-id: http://code.djangoproject.com/svn/django/trunk@10003 bcc190cf-cafb-0310-a4f2-bffc1f526a37
This commit is contained in:
		| @@ -12,3 +12,6 @@ class WKTAdaptor(object): | |||||||
|  |  | ||||||
|     def __str__(self): |     def __str__(self): | ||||||
|         return self.wkt |         return self.wkt | ||||||
|  |  | ||||||
|  |     def prepare_database_save(self, unused): | ||||||
|  |         return self | ||||||
|   | |||||||
| @@ -31,3 +31,6 @@ class PostGISAdaptor(object): | |||||||
|         "Returns a properly quoted string for use in PostgreSQL/PostGIS." |         "Returns a properly quoted string for use in PostgreSQL/PostGIS." | ||||||
|         # Want to use WKB, so wrap with psycopg2 Binary() to quote properly. |         # Want to use WKB, so wrap with psycopg2 Binary() to quote properly. | ||||||
|         return "%s(%s, %s)" % (GEOM_FROM_WKB, Binary(self.wkb), self.srid or -1) |         return "%s(%s, %s)" % (GEOM_FROM_WKB, Binary(self.wkb), self.srid or -1) | ||||||
|  |  | ||||||
|  |     def prepare_database_save(self, unused): | ||||||
|  |         return self | ||||||
|   | |||||||
							
								
								
									
										18
									
								
								django/contrib/gis/tests/geoapp/test_regress.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										18
									
								
								django/contrib/gis/tests/geoapp/test_regress.py
									
									
									
									
									
										Normal file
									
								
							| @@ -0,0 +1,18 @@ | |||||||
|  | import os, unittest | ||||||
|  | from django.contrib.gis.db.backend import SpatialBackend | ||||||
|  | from django.contrib.gis.tests.utils import no_mysql, no_oracle, no_postgis | ||||||
|  | from models import City | ||||||
|  |  | ||||||
|  | class GeoRegressionTests(unittest.TestCase): | ||||||
|  |  | ||||||
|  |     def test01_update(self): | ||||||
|  |         "Testing GeoQuerySet.update(), see #10411." | ||||||
|  |         pnt = City.objects.get(name='Pueblo').point | ||||||
|  |         bak = pnt.clone() | ||||||
|  |         pnt.y += 0.005 | ||||||
|  |         pnt.x += 0.005 | ||||||
|  |  | ||||||
|  |         City.objects.filter(name='Pueblo').update(point=pnt) | ||||||
|  |         self.assertEqual(pnt, City.objects.get(name='Pueblo').point) | ||||||
|  |         City.objects.filter(name='Pueblo').update(point=bak) | ||||||
|  |         self.assertEqual(bak, City.objects.get(name='Pueblo').point) | ||||||
| @@ -204,8 +204,8 @@ class GeoModelTest(unittest.TestCase): | |||||||
|         self.assertRaises(TypeError, Country.objects.make_line) |         self.assertRaises(TypeError, Country.objects.make_line) | ||||||
|         # Reference query: |         # Reference query: | ||||||
|         # SELECT AsText(ST_MakeLine(geoapp_city.point)) FROM geoapp_city; |         # SELECT AsText(ST_MakeLine(geoapp_city.point)) FROM geoapp_city; | ||||||
|         self.assertEqual(GEOSGeometry('LINESTRING(-95.363151 29.763374,-96.801611 32.782057,-97.521157 34.464642,174.783117 -41.315268,-104.609252 38.255001,-95.23506 38.971823,-87.650175 41.850385,-123.305196 48.462611)', srid=4326), |         ref_line = GEOSGeometry('LINESTRING(-95.363151 29.763374,-96.801611 32.782057,-97.521157 34.464642,174.783117 -41.315268,-104.609252 38.255001,-95.23506 38.971823,-87.650175 41.850385,-123.305196 48.462611)', srid=4326) | ||||||
|                          City.objects.make_line()) |         self.assertEqual(ref_line, City.objects.make_line()) | ||||||
|  |  | ||||||
|     def test09_disjoint(self): |     def test09_disjoint(self): | ||||||
|         "Testing the `disjoint` lookup type." |         "Testing the `disjoint` lookup type." | ||||||
| @@ -571,10 +571,13 @@ class GeoModelTest(unittest.TestCase): | |||||||
|         for pc in qs: self.assertEqual(32128, pc.point.srid) |         for pc in qs: self.assertEqual(32128, pc.point.srid) | ||||||
|  |  | ||||||
| from test_feeds import GeoFeedTest | from test_feeds import GeoFeedTest | ||||||
|  | from test_regress import GeoRegressionTests | ||||||
| from test_sitemaps import GeoSitemapTest | from test_sitemaps import GeoSitemapTest | ||||||
|  |  | ||||||
| def suite(): | def suite(): | ||||||
|     s = unittest.TestSuite() |     s = unittest.TestSuite() | ||||||
|     s.addTest(unittest.makeSuite(GeoModelTest)) |     s.addTest(unittest.makeSuite(GeoModelTest)) | ||||||
|     s.addTest(unittest.makeSuite(GeoFeedTest)) |     s.addTest(unittest.makeSuite(GeoFeedTest)) | ||||||
|     s.addTest(unittest.makeSuite(GeoSitemapTest)) |     s.addTest(unittest.makeSuite(GeoSitemapTest)) | ||||||
|  |     s.addTest(unittest.makeSuite(GeoRegressionTests)) | ||||||
|     return s |     return s | ||||||
|   | |||||||
| @@ -174,10 +174,13 @@ class GeoModelTest(unittest.TestCase): | |||||||
|         self.assertRaises(ImproperlyConfigured, Country.objects.all().gml, field_name='mpoly') |         self.assertRaises(ImproperlyConfigured, Country.objects.all().gml, field_name='mpoly') | ||||||
|  |  | ||||||
| from test_feeds import GeoFeedTest | from test_feeds import GeoFeedTest | ||||||
|  | from test_regress import GeoRegressionTests | ||||||
| from test_sitemaps import GeoSitemapTest | from test_sitemaps import GeoSitemapTest | ||||||
|  |  | ||||||
| def suite(): | def suite(): | ||||||
|     s = unittest.TestSuite() |     s = unittest.TestSuite() | ||||||
|     s.addTest(unittest.makeSuite(GeoModelTest)) |     s.addTest(unittest.makeSuite(GeoModelTest)) | ||||||
|     s.addTest(unittest.makeSuite(GeoFeedTest)) |     s.addTest(unittest.makeSuite(GeoFeedTest)) | ||||||
|     s.addTest(unittest.makeSuite(GeoSitemapTest)) |     s.addTest(unittest.makeSuite(GeoSitemapTest)) | ||||||
|  |     s.addTest(unittest.makeSuite(GeoRegressionTests)) | ||||||
|     return s |     return s | ||||||
|   | |||||||
| @@ -499,6 +499,8 @@ class Model(object): | |||||||
|             setattr(self, cachename, obj) |             setattr(self, cachename, obj) | ||||||
|         return getattr(self, cachename) |         return getattr(self, cachename) | ||||||
|  |  | ||||||
|  |     def prepare_database_save(self, unused): | ||||||
|  |         return self.pk | ||||||
|  |  | ||||||
|  |  | ||||||
| ############################################ | ############################################ | ||||||
|   | |||||||
| @@ -90,6 +90,9 @@ class ExpressionNode(tree.Node): | |||||||
|     def __ror__(self, other): |     def __ror__(self, other): | ||||||
|         return self._combine(other, self.OR, True) |         return self._combine(other, self.OR, True) | ||||||
|  |  | ||||||
|  |     def prepare_database_save(self, unused): | ||||||
|  |         return self | ||||||
|  |  | ||||||
| class F(ExpressionNode): | class F(ExpressionNode): | ||||||
|     """ |     """ | ||||||
|     An expression representing the value of the given field. |     An expression representing the value of the given field. | ||||||
|   | |||||||
| @@ -239,9 +239,10 @@ class UpdateQuery(Query): | |||||||
|         """ |         """ | ||||||
|         from django.db.models.base import Model |         from django.db.models.base import Model | ||||||
|         for field, model, val in values_seq: |         for field, model, val in values_seq: | ||||||
|             # FIXME: Some sort of db_prep_* is probably more appropriate here. |             if hasattr(val, 'prepare_database_save'): | ||||||
|             if field.rel and isinstance(val, Model): |                 val = val.prepare_database_save(field) | ||||||
|                 val = val.pk |             else: | ||||||
|  |                 val = field.get_db_prep_save(val) | ||||||
|  |  | ||||||
|             # Getting the placeholder for the field. |             # Getting the placeholder for the field. | ||||||
|             if hasattr(field, 'get_placeholder'): |             if hasattr(field, 'get_placeholder'): | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user