mirror of
				https://github.com/django/django.git
				synced 2025-10-25 06:36:07 +00:00 
			
		
		
		
	Refs #19544 -- Extracted ManyRelatedManager.add() missing ids logic to a method.
This commit is contained in:
		
				
					committed by
					
						 Tim Graham
						Tim Graham
					
				
			
			
				
	
			
			
			
						parent
						
							0ac4e51b2c
						
					
				
				
					commit
					dd32f9a3a2
				
			| @@ -1051,6 +1051,44 @@ def create_forward_many_to_many_manager(superclass, rel, reverse): | |||||||
|             return obj, created |             return obj, created | ||||||
|         update_or_create.alters_data = True |         update_or_create.alters_data = True | ||||||
|  |  | ||||||
|  |         def _get_missing_target_ids(self, source_field_name, target_field_name, db, objs): | ||||||
|  |             """ | ||||||
|  |             Return the subset of ids of `objs` that aren't already assigned to | ||||||
|  |             this relationship. | ||||||
|  |             """ | ||||||
|  |             from django.db.models import Model | ||||||
|  |             target_ids = set() | ||||||
|  |             target_field = self.through._meta.get_field(target_field_name) | ||||||
|  |             for obj in objs: | ||||||
|  |                 if isinstance(obj, self.model): | ||||||
|  |                     if not router.allow_relation(obj, self.instance): | ||||||
|  |                         raise ValueError( | ||||||
|  |                             'Cannot add "%r": instance is on database "%s", ' | ||||||
|  |                             'value is on database "%s"' % | ||||||
|  |                             (obj, self.instance._state.db, obj._state.db) | ||||||
|  |                         ) | ||||||
|  |                     target_id = target_field.get_foreign_related_value(obj)[0] | ||||||
|  |                     if target_id is None: | ||||||
|  |                         raise ValueError( | ||||||
|  |                             'Cannot add "%r": the value for field "%s" is None' % | ||||||
|  |                             (obj, target_field_name) | ||||||
|  |                         ) | ||||||
|  |                     target_ids.add(target_id) | ||||||
|  |                 elif isinstance(obj, Model): | ||||||
|  |                     raise TypeError( | ||||||
|  |                         "'%s' instance expected, got %r" % | ||||||
|  |                         (self.model._meta.object_name, obj) | ||||||
|  |                     ) | ||||||
|  |                 else: | ||||||
|  |                     target_ids.add(obj) | ||||||
|  |             vals = self.through._default_manager.using(db).values_list( | ||||||
|  |                 target_field_name, flat=True | ||||||
|  |             ).filter(**{ | ||||||
|  |                 source_field_name: self.related_val[0], | ||||||
|  |                 '%s__in' % target_field_name: target_ids, | ||||||
|  |             }) | ||||||
|  |             return target_ids.difference(vals) | ||||||
|  |  | ||||||
|         def _add_items(self, source_field_name, target_field_name, *objs, through_defaults=None): |         def _add_items(self, source_field_name, target_field_name, *objs, through_defaults=None): | ||||||
|             # source_field_name: the PK fieldname in join table for the source object |             # source_field_name: the PK fieldname in join table for the source object | ||||||
|             # target_field_name: the PK fieldname in join table for the target object |             # target_field_name: the PK fieldname in join table for the target object | ||||||
| @@ -1058,40 +1096,9 @@ def create_forward_many_to_many_manager(superclass, rel, reverse): | |||||||
|             through_defaults = through_defaults or {} |             through_defaults = through_defaults or {} | ||||||
|  |  | ||||||
|             # If there aren't any objects, there is nothing to do. |             # If there aren't any objects, there is nothing to do. | ||||||
|             from django.db.models import Model |  | ||||||
|             if objs: |             if objs: | ||||||
|                 new_ids = set() |  | ||||||
|                 for obj in objs: |  | ||||||
|                     if isinstance(obj, self.model): |  | ||||||
|                         if not router.allow_relation(obj, self.instance): |  | ||||||
|                             raise ValueError( |  | ||||||
|                                 'Cannot add "%r": instance is on database "%s", value is on database "%s"' % |  | ||||||
|                                 (obj, self.instance._state.db, obj._state.db) |  | ||||||
|                             ) |  | ||||||
|                         fk_val = self.through._meta.get_field( |  | ||||||
|                             target_field_name).get_foreign_related_value(obj)[0] |  | ||||||
|                         if fk_val is None: |  | ||||||
|                             raise ValueError( |  | ||||||
|                                 'Cannot add "%r": the value for field "%s" is None' % |  | ||||||
|                                 (obj, target_field_name) |  | ||||||
|                             ) |  | ||||||
|                         new_ids.add(fk_val) |  | ||||||
|                     elif isinstance(obj, Model): |  | ||||||
|                         raise TypeError( |  | ||||||
|                             "'%s' instance expected, got %r" % |  | ||||||
|                             (self.model._meta.object_name, obj) |  | ||||||
|                         ) |  | ||||||
|                     else: |  | ||||||
|                         new_ids.add(obj) |  | ||||||
|  |  | ||||||
|                 db = router.db_for_write(self.through, instance=self.instance) |                 db = router.db_for_write(self.through, instance=self.instance) | ||||||
|                 vals = (self.through._default_manager.using(db) |                 missing_target_ids = self._get_missing_target_ids(source_field_name, target_field_name, db, objs) | ||||||
|                         .values_list(target_field_name, flat=True) |  | ||||||
|                         .filter(**{ |  | ||||||
|                             source_field_name: self.related_val[0], |  | ||||||
|                             '%s__in' % target_field_name: new_ids, |  | ||||||
|                         })) |  | ||||||
|                 new_ids.difference_update(vals) |  | ||||||
|  |  | ||||||
|                 with transaction.atomic(using=db, savepoint=False): |                 with transaction.atomic(using=db, savepoint=False): | ||||||
|                     if self.reverse or source_field_name == self.source_field_name: |                     if self.reverse or source_field_name == self.source_field_name: | ||||||
| @@ -1100,16 +1107,16 @@ def create_forward_many_to_many_manager(superclass, rel, reverse): | |||||||
|                         signals.m2m_changed.send( |                         signals.m2m_changed.send( | ||||||
|                             sender=self.through, action='pre_add', |                             sender=self.through, action='pre_add', | ||||||
|                             instance=self.instance, reverse=self.reverse, |                             instance=self.instance, reverse=self.reverse, | ||||||
|                             model=self.model, pk_set=new_ids, using=db, |                             model=self.model, pk_set=missing_target_ids, using=db, | ||||||
|                         ) |                         ) | ||||||
|  |  | ||||||
|                     # Add the ones that aren't there already |                     # Add the ones that aren't there already | ||||||
|                     self.through._default_manager.using(db).bulk_create([ |                     self.through._default_manager.using(db).bulk_create([ | ||||||
|                         self.through(**through_defaults, **{ |                         self.through(**through_defaults, **{ | ||||||
|                             '%s_id' % source_field_name: self.related_val[0], |                             '%s_id' % source_field_name: self.related_val[0], | ||||||
|                             '%s_id' % target_field_name: obj_id, |                             '%s_id' % target_field_name: target_id, | ||||||
|                         }) |                         }) | ||||||
|                         for obj_id in new_ids |                         for target_id in missing_target_ids | ||||||
|                     ]) |                     ]) | ||||||
|  |  | ||||||
|                     if self.reverse or source_field_name == self.source_field_name: |                     if self.reverse or source_field_name == self.source_field_name: | ||||||
| @@ -1118,7 +1125,7 @@ def create_forward_many_to_many_manager(superclass, rel, reverse): | |||||||
|                         signals.m2m_changed.send( |                         signals.m2m_changed.send( | ||||||
|                             sender=self.through, action='post_add', |                             sender=self.through, action='post_add', | ||||||
|                             instance=self.instance, reverse=self.reverse, |                             instance=self.instance, reverse=self.reverse, | ||||||
|                             model=self.model, pk_set=new_ids, using=db, |                             model=self.model, pk_set=missing_target_ids, using=db, | ||||||
|                         ) |                         ) | ||||||
|  |  | ||||||
|         def _remove_items(self, source_field_name, target_field_name, *objs): |         def _remove_items(self, source_field_name, target_field_name, *objs): | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user