mirror of
				https://github.com/django/django.git
				synced 2025-10-25 22:56:12 +00:00 
			
		
		
		
	Fixed #19501 -- added Model.from_db() method
The Model.from_db() is intended to be used in cases where customization of model loading is needed. Reasons can be performance, or adding custom behavior to the model (for example "dirty field tracking" to issue automatic update_fields when saving models). A big thank you to Tim Graham for the review!
This commit is contained in:
		| @@ -458,6 +458,16 @@ class Model(six.with_metaclass(ModelBase)): | |||||||
|         super(Model, self).__init__() |         super(Model, self).__init__() | ||||||
|         signals.post_init.send(sender=self.__class__, instance=self) |         signals.post_init.send(sender=self.__class__, instance=self) | ||||||
|  |  | ||||||
|  |     @classmethod | ||||||
|  |     def from_db(cls, db, field_names, values): | ||||||
|  |         if cls._deferred: | ||||||
|  |             new = cls(**dict(zip(field_names, values))) | ||||||
|  |         else: | ||||||
|  |             new = cls(*values) | ||||||
|  |         new._state.adding = False | ||||||
|  |         new._state.db = db | ||||||
|  |         return new | ||||||
|  |  | ||||||
|     def __repr__(self): |     def __repr__(self): | ||||||
|         try: |         try: | ||||||
|             u = six.text_type(self) |             u = six.text_type(self) | ||||||
|   | |||||||
| @@ -241,7 +241,6 @@ class QuerySet(object): | |||||||
|         aggregate_select = list(self.query.aggregate_select) |         aggregate_select = list(self.query.aggregate_select) | ||||||
|  |  | ||||||
|         only_load = self.query.get_loaded_field_names() |         only_load = self.query.get_loaded_field_names() | ||||||
|         if not fill_cache: |  | ||||||
|         fields = self.model._meta.concrete_fields |         fields = self.model._meta.concrete_fields | ||||||
|  |  | ||||||
|         load_fields = [] |         load_fields = [] | ||||||
| @@ -260,9 +259,6 @@ class QuerySet(object): | |||||||
|                     # Therefore, we need to load all fields from this model |                     # Therefore, we need to load all fields from this model | ||||||
|                     load_fields.append(field.name) |                     load_fields.append(field.name) | ||||||
|  |  | ||||||
|         index_start = len(extra_select) |  | ||||||
|         aggregate_start = index_start + len(load_fields or self.model._meta.concrete_fields) |  | ||||||
|  |  | ||||||
|         skip = None |         skip = None | ||||||
|         if load_fields and not fill_cache: |         if load_fields and not fill_cache: | ||||||
|             # Some fields have been deferred, so we have to initialize |             # Some fields have been deferred, so we have to initialize | ||||||
| @@ -275,30 +271,25 @@ class QuerySet(object): | |||||||
|                 else: |                 else: | ||||||
|                     init_list.append(field.attname) |                     init_list.append(field.attname) | ||||||
|             model_cls = deferred_class_factory(self.model, skip) |             model_cls = deferred_class_factory(self.model, skip) | ||||||
|  |         else: | ||||||
|  |             model_cls = self.model | ||||||
|  |             init_list = [f.attname for f in fields] | ||||||
|  |  | ||||||
|         # Cache db and model outside the loop |         # Cache db and model outside the loop | ||||||
|         db = self.db |         db = self.db | ||||||
|         model = self.model |  | ||||||
|         compiler = self.query.get_compiler(using=db) |         compiler = self.query.get_compiler(using=db) | ||||||
|  |         index_start = len(extra_select) | ||||||
|  |         aggregate_start = index_start + len(init_list) | ||||||
|  |  | ||||||
|         if fill_cache: |         if fill_cache: | ||||||
|             klass_info = get_klass_info(model, max_depth=max_depth, |             klass_info = get_klass_info(model_cls, max_depth=max_depth, | ||||||
|                                         requested=requested, only_load=only_load) |                                         requested=requested, only_load=only_load) | ||||||
|         for row in compiler.results_iter(): |         for row in compiler.results_iter(): | ||||||
|             if fill_cache: |             if fill_cache: | ||||||
|                 obj, _ = get_cached_row(row, index_start, db, klass_info, |                 obj, _ = get_cached_row(row, index_start, db, klass_info, | ||||||
|                                         offset=len(aggregate_select)) |                                         offset=len(aggregate_select)) | ||||||
|             else: |             else: | ||||||
|                 # Omit aggregates in object creation. |                 obj = model_cls.from_db(db, init_list, row[index_start:aggregate_start]) | ||||||
|                 row_data = row[index_start:aggregate_start] |  | ||||||
|                 if skip: |  | ||||||
|                     obj = model_cls(**dict(zip(init_list, row_data))) |  | ||||||
|                 else: |  | ||||||
|                     obj = model(*row_data) |  | ||||||
|  |  | ||||||
|                 # Store the source database of the object |  | ||||||
|                 obj._state.db = db |  | ||||||
|                 # This object came from the database; it's not being added. |  | ||||||
|                 obj._state.adding = False |  | ||||||
|  |  | ||||||
|             if extra_select: |             if extra_select: | ||||||
|                 for i, k in enumerate(extra_select): |                 for i, k in enumerate(extra_select): | ||||||
| @@ -1417,6 +1408,21 @@ def get_klass_info(klass, max_depth=0, cur_depth=0, requested=None, | |||||||
|     return klass, field_names, field_count, related_fields, reverse_related_fields, pk_idx |     return klass, field_names, field_count, related_fields, reverse_related_fields, pk_idx | ||||||
|  |  | ||||||
|  |  | ||||||
|  | def reorder_for_init(model, field_names, values): | ||||||
|  |     """ | ||||||
|  |     Reorders given field names and values for those fields | ||||||
|  |     to be in the same order as model.__init__() expects to find them. | ||||||
|  |     """ | ||||||
|  |     new_names, new_values = [], [] | ||||||
|  |     for f in model._meta.concrete_fields: | ||||||
|  |         if f.attname not in field_names: | ||||||
|  |             continue | ||||||
|  |         new_names.append(f.attname) | ||||||
|  |         new_values.append(values[field_names.index(f.attname)]) | ||||||
|  |     assert len(new_names) == len(field_names) | ||||||
|  |     return new_names, new_values | ||||||
|  |  | ||||||
|  |  | ||||||
| def get_cached_row(row, index_start, using, klass_info, offset=0, | def get_cached_row(row, index_start, using, klass_info, offset=0, | ||||||
|                    parent_data=()): |                    parent_data=()): | ||||||
|     """ |     """ | ||||||
| @@ -1451,18 +1457,19 @@ def get_cached_row(row, index_start, using, klass_info, offset=0, | |||||||
|          fields[pk_idx] == '')): |          fields[pk_idx] == '')): | ||||||
|         obj = None |         obj = None | ||||||
|     elif field_names: |     elif field_names: | ||||||
|         fields = list(fields) |         values = list(fields) | ||||||
|  |         parent_values = [] | ||||||
|  |         parent_field_names = [] | ||||||
|         for rel_field, value in parent_data: |         for rel_field, value in parent_data: | ||||||
|             field_names.append(rel_field.attname) |             parent_field_names.append(rel_field.attname) | ||||||
|             fields.append(value) |             parent_values.append(value) | ||||||
|         obj = klass(**dict(zip(field_names, fields))) |         field_names, values = reorder_for_init( | ||||||
|  |             klass, parent_field_names + field_names, | ||||||
|  |             parent_values + values) | ||||||
|  |         obj = klass.from_db(using, field_names, values) | ||||||
|     else: |     else: | ||||||
|         obj = klass(*fields) |         field_names = [f.attname for f in klass._meta.concrete_fields] | ||||||
|     # If an object was retrieved, set the database state. |         obj = klass.from_db(using, field_names, fields) | ||||||
|     if obj: |  | ||||||
|         obj._state.db = using |  | ||||||
|         obj._state.adding = False |  | ||||||
|  |  | ||||||
|     # Instantiate related fields |     # Instantiate related fields | ||||||
|     index_end = index_start + field_count + offset |     index_end = index_start + field_count + offset | ||||||
|     # Iterate over each related object, populating any |     # Iterate over each related object, populating any | ||||||
| @@ -1534,15 +1541,18 @@ class RawQuerySet(object): | |||||||
|         self.params = params or () |         self.params = params or () | ||||||
|         self.translations = translations or {} |         self.translations = translations or {} | ||||||
|  |  | ||||||
|     def __iter__(self): |     def resolve_model_init_order(self): | ||||||
|         # Mapping of attrnames to row column positions. Used for constructing |         """ | ||||||
|         # the model using kwargs, needed when not all model's fields are present |         Resolve the init field names and value positions | ||||||
|         # in the query. |         """ | ||||||
|         model_init_field_names = {} |         model_init_names = [f.attname for f in self.model._meta.fields | ||||||
|         # A list of tuples of (column name, column position). Used for |                             if f.attname in self.columns] | ||||||
|         # annotation fields. |         annotation_fields = [(column, pos) for pos, column in enumerate(self.columns) | ||||||
|         annotation_fields = [] |                              if column not in self.model_fields] | ||||||
|  |         model_init_order = [self.columns.index(fname) for fname in model_init_names] | ||||||
|  |         return model_init_names, model_init_order, annotation_fields | ||||||
|  |  | ||||||
|  |     def __iter__(self): | ||||||
|         # Cache some things for performance reasons outside the loop. |         # Cache some things for performance reasons outside the loop. | ||||||
|         db = self.db |         db = self.db | ||||||
|         compiler = connections[db].ops.compiler('SQLCompiler')( |         compiler = connections[db].ops.compiler('SQLCompiler')( | ||||||
| @@ -1553,18 +1563,12 @@ class RawQuerySet(object): | |||||||
|         query = iter(self.query) |         query = iter(self.query) | ||||||
|  |  | ||||||
|         try: |         try: | ||||||
|             # Find out which columns are model's fields, and which ones should be |             model_init_names, model_init_pos, annotation_fields = self.resolve_model_init_order() | ||||||
|             # annotated to the model. |  | ||||||
|             for pos, column in enumerate(self.columns): |  | ||||||
|                 if column in self.model_fields: |  | ||||||
|                     model_init_field_names[self.model_fields[column].attname] = pos |  | ||||||
|                 else: |  | ||||||
|                     annotation_fields.append((column, pos)) |  | ||||||
|  |  | ||||||
|             # Find out which model's fields are not present in the query. |             # Find out which model's fields are not present in the query. | ||||||
|             skip = set() |             skip = set() | ||||||
|             for field in self.model._meta.fields: |             for field in self.model._meta.fields: | ||||||
|                 if field.attname not in model_init_field_names: |                 if field.attname not in model_init_names: | ||||||
|                     skip.add(field.attname) |                     skip.add(field.attname) | ||||||
|             if skip: |             if skip: | ||||||
|                 if self.model._meta.pk.attname in skip: |                 if self.model._meta.pk.attname in skip: | ||||||
| @@ -1572,34 +1576,17 @@ class RawQuerySet(object): | |||||||
|                 model_cls = deferred_class_factory(self.model, skip) |                 model_cls = deferred_class_factory(self.model, skip) | ||||||
|             else: |             else: | ||||||
|                 model_cls = self.model |                 model_cls = self.model | ||||||
|                 # All model's fields are present in the query. So, it is possible |  | ||||||
|                 # to use *args based model instantiation. For each field of the model, |  | ||||||
|                 # record the query column position matching that field. |  | ||||||
|                 model_init_field_pos = [] |  | ||||||
|                 for field in self.model._meta.fields: |  | ||||||
|                     model_init_field_pos.append(model_init_field_names[field.attname]) |  | ||||||
|             if need_resolv_columns: |             if need_resolv_columns: | ||||||
|                 fields = [self.model_fields.get(c, None) for c in self.columns] |                 fields = [self.model_fields.get(c, None) for c in self.columns] | ||||||
|             # Begin looping through the query values. |  | ||||||
|             for values in query: |             for values in query: | ||||||
|                 if need_resolv_columns: |                 if need_resolv_columns: | ||||||
|                     values = compiler.resolve_columns(values, fields) |                     values = compiler.resolve_columns(values, fields) | ||||||
|                 # Associate fields to values |                 # Associate fields to values | ||||||
|                 if skip: |                 model_init_values = [values[pos] for pos in model_init_pos] | ||||||
|                     model_init_kwargs = {} |                 instance = model_cls.from_db(db, model_init_names, model_init_values) | ||||||
|                     for attname, pos in six.iteritems(model_init_field_names): |  | ||||||
|                         model_init_kwargs[attname] = values[pos] |  | ||||||
|                     instance = model_cls(**model_init_kwargs) |  | ||||||
|                 else: |  | ||||||
|                     model_init_args = [values[pos] for pos in model_init_field_pos] |  | ||||||
|                     instance = model_cls(*model_init_args) |  | ||||||
|                 if annotation_fields: |                 if annotation_fields: | ||||||
|                     for column, pos in annotation_fields: |                     for column, pos in annotation_fields: | ||||||
|                         setattr(instance, column, values[pos]) |                         setattr(instance, column, values[pos]) | ||||||
|  |  | ||||||
|                 instance._state.db = db |  | ||||||
|                 instance._state.adding = False |  | ||||||
|  |  | ||||||
|                 yield instance |                 yield instance | ||||||
|         finally: |         finally: | ||||||
|             # Done iterating the Query. If it has its own cursor, close it. |             # Done iterating the Query. If it has its own cursor, close it. | ||||||
|   | |||||||
| @@ -62,6 +62,60 @@ that, you need to :meth:`~Model.save()`. | |||||||
|  |  | ||||||
|         book = Book.objects.create_book("Pride and Prejudice") |         book = Book.objects.create_book("Pride and Prejudice") | ||||||
|  |  | ||||||
|  | Customizing model loading | ||||||
|  | ------------------------- | ||||||
|  |  | ||||||
|  | .. classmethod:: Model.from_db(db, field_names, values) | ||||||
|  |  | ||||||
|  | .. versionadded:: 1.8 | ||||||
|  |  | ||||||
|  | The ``from_db()`` method can be used to customize model instance creation | ||||||
|  | when loading from the database. | ||||||
|  |  | ||||||
|  | The ``db`` argument contains the database alias for the database the model | ||||||
|  | is loaded from, ``field_names`` contains the names of all loaded fields, and | ||||||
|  | ``values`` contains the loaded values for each field in ``field_names``. The | ||||||
|  | ``field_names`` are in the same order as the ``values``, so it is possible to | ||||||
|  | use ``cls(**(zip(field_names, values)))`` to instantiate the object. If all | ||||||
|  | of the model's fields are present, then ``values`` are guaranteed to be in | ||||||
|  | the order ``__init__()`` expects them. That is, the instance can be created | ||||||
|  | by ``cls(*values)``. It is possible to check if all fields are present by | ||||||
|  | consulting ``cls._deferred`` - if ``False``, then all fields have been loaded | ||||||
|  | from the database. | ||||||
|  |  | ||||||
|  | In addition to creating the new model, the ``from_db()`` method must set the | ||||||
|  | ``adding`` and ``db`` flags in the new instance's ``_state`` attribute. | ||||||
|  |  | ||||||
|  | Below is an example showing how torecord the initial values of fields that | ||||||
|  | are loaded from the database::  | ||||||
|  |  | ||||||
|  |     @classmethod | ||||||
|  |     def from_db(cls, db, field_names, values): | ||||||
|  |         # default implementation of from_db() (could be replaced | ||||||
|  |         # with super()) | ||||||
|  |         if cls._deferred: | ||||||
|  |             instance = cls(**zip(field_names, values)) | ||||||
|  |         else: | ||||||
|  |             instance = cls(*values) | ||||||
|  |         instance._state.adding = False | ||||||
|  |         instance._state.db = db | ||||||
|  |         # customization to store the original field values on the instance | ||||||
|  |         instance._loaded_values = zip(field_names, values) | ||||||
|  |         return instance | ||||||
|  |  | ||||||
|  |     def save(self, *args, **kwargs): | ||||||
|  |         # Check how the current values differ from ._loaded_values. For example, | ||||||
|  |         # prevent changing the creator_id of the model. (This example doesn't | ||||||
|  |         # support cases where 'creator_id' is deferred). | ||||||
|  |         if not self._state.adding and ( | ||||||
|  |                 self.creator_id != self._loaded_values['creator_id']): | ||||||
|  |             raise ValueError("Updating the value of creator isn't allowed") | ||||||
|  |         super(...).save(*args, **kwargs) | ||||||
|  |  | ||||||
|  | The example above shows a full ``from_db()`` implementation to clarify how that | ||||||
|  | is done. In this case it would of course be possible to just use ``super()`` call | ||||||
|  | in the ``from_db()`` method. | ||||||
|  |  | ||||||
| .. _validating-objects: | .. _validating-objects: | ||||||
|  |  | ||||||
| Validating objects | Validating objects | ||||||
|   | |||||||
| @@ -193,6 +193,10 @@ Models | |||||||
|   when these objects are unpickled in a different version than the one in |   when these objects are unpickled in a different version than the one in | ||||||
|   which they were pickled. |   which they were pickled. | ||||||
|  |  | ||||||
|  | * Added :meth:`Model.from_db() <django.db.models.Model.from_db()>` which | ||||||
|  |   Django uses whenever objects are loaded using the ORM. The method allows | ||||||
|  |   customizing model loading behavior. | ||||||
|  |  | ||||||
| Signals | Signals | ||||||
| ^^^^^^^ | ^^^^^^^ | ||||||
|  |  | ||||||
|   | |||||||
		Reference in New Issue
	
	Block a user